From 3da893a4f43cd703e2cb50318c5a6ad14907ad5f Mon Sep 17 00:00:00 2001 From: Giulio Eulisse <10544+ktf@users.noreply.github.com> Date: Mon, 1 Jun 2026 18:31:00 +0200 Subject: [PATCH 1/2] Avoid having intermediary vectors for the NN-based TPC PID --- Common/Tools/PID/pidTPCModule.h | 269 +++++++++---- Tools/ML/CMakeLists.txt | 2 +- Tools/ML/model.h | 485 +++++++++++++++++++++++ dependencies/O2PhysicsDependencies.cmake | 1 + 4 files changed, 685 insertions(+), 72 deletions(-) diff --git a/Common/Tools/PID/pidTPCModule.h b/Common/Tools/PID/pidTPCModule.h index 24c7683b70c..8de26239896 100644 --- a/Common/Tools/PID/pidTPCModule.h +++ b/Common/Tools/PID/pidTPCModule.h @@ -414,6 +414,7 @@ class pidTPCModule network.initModel(pidTPCopts.networkPathLocally.value, pidTPCopts.enableNetworkOptimizations.value, pidTPCopts.networkSetNumThreads.value, strtoul(headers["Valid-From"].c_str(), NULL, 0), strtoul(headers["Valid-Until"].c_str(), NULL, 0)); std::vector dummyInput(network.getNumInputNodes(), 1.); network.evalModel(dummyInput); /// Init the model evaluations + setupColumnInputNetwork(); LOGP(info, "Retrieved NN corrections for production tag {}, pass number {}, and NN-Version {}", headers["LPMProductionTag"], headers["RecoPassName"], headers["NN-Version"]); } else { LOG(fatal) << "No valid NN object found matching retrieved Bethe-Bloch parametrisation for pass " << metadata["RecoPassName"] << ". Please ensure that the requested pass has dedicated NN corrections available"; @@ -427,6 +428,7 @@ class pidTPCModule network.initModel(pidTPCopts.networkPathLocally.value, pidTPCopts.enableNetworkOptimizations.value, pidTPCopts.networkSetNumThreads.value); std::vector dummyInput(network.getNumInputNodes(), 1.); network.evalModel(dummyInput); // This is an initialisation and might reduce the overhead of the model + setupColumnInputNetwork(); } } else { return; @@ -438,6 +440,110 @@ class pidTPCModule } } // end init + //__________________________________________________ + void setupColumnInputNetwork() + { + using PI = o2::ml::OnnxModel::PreprocInput; + using PF = o2::ml::OnnxModel::PreprocFeature; + const int nFeat = network.getNumInputNodes(); // # network features (6..9), original model + + // Raw graph inputs (this order defines the tensor feeding order in + // createNetworkPrediction). All per-track columns are wrapped zero-copy from + // the Arrow buffers; nclNorm/hrDivisor/hrFallback are per-DF runtime scalars. + std::vector in; + in.push_back({"tpcInnerParam", PI::Type::TrackFloat}); + in.push_back({"tgl", PI::Type::TrackFloat}); + in.push_back({"signed1Pt", PI::Type::TrackFloat}); + in.push_back({"mass", PI::Type::ScalarFloat}); + in.push_back({"collisionId", PI::Type::TrackInt32}); + in.push_back({"multArray", PI::Type::CollisionFloat}); + in.push_back({"nclNorm", PI::Type::ScalarFloat}); + in.push_back({"nclsFindable", PI::Type::TrackUint8}); + in.push_back({"nclsFMF", PI::Type::TrackInt8}); + if (nFeat >= 7) { + in.push_back({"occArray", PI::Type::CollisionFloat}); + } + if (nFeat >= 8) { + in.push_back({"hrArray", PI::Type::CollisionFloat}); + in.push_back({"hrDivisor", PI::Type::ScalarFloat}); + in.push_back({"hrFallback", PI::Type::ScalarFloat}); + } + if (nFeat >= 9) { + in.push_back({"phi", PI::Type::TrackFloat}); + } + in.push_back({"validMask", PI::Type::TrackBool}); + + // Per-feature preprocessing (exactly nFeat entries, in the training order). + std::vector feat; + { + PF f; + f.op = PF::Op::Passthrough; + f.a = "tpcInnerParam"; + feat.push_back(f); + } + { + PF f; + f.op = PF::Op::Passthrough; + f.a = "tgl"; + feat.push_back(f); + } + { + PF f; + f.op = PF::Op::Passthrough; + f.a = "signed1Pt"; + feat.push_back(f); + } + { + PF f; + f.op = PF::Op::BroadcastScalar; + f.a = "mass"; + f.shapeRef = "collisionId"; + feat.push_back(f); + } + { + PF f; + f.op = PF::Op::GatherNormWhere; + f.a = "multArray"; + f.b = "collisionId"; + f.c = {11000.f, 1.f, 0.f}; + feat.push_back(f); + } + { + PF f; + f.op = PF::Op::NClsSqrtRecip; + f.a = "nclsFindable"; + f.b = "nclsFMF"; + f.scaleInput = "nclNorm"; + feat.push_back(f); + } + if (nFeat >= 7) { + PF f; + f.op = PF::Op::GatherNormWhere; + f.a = "occArray"; + f.b = "collisionId"; + f.c = {60000.f, 1.f, 0.f}; + feat.push_back(f); + } + if (nFeat >= 8) { + PF f; + f.op = PF::Op::GatherNormWhere; + f.a = "hrArray"; + f.b = "collisionId"; + f.scaleInput = "hrDivisor"; + f.fallbackInput = "hrFallback"; + feat.push_back(f); + } + if (nFeat >= 9) { + PF f; + f.op = PF::Op::Mod2; + f.a = "phi"; + f.c = {2.f * static_cast(M_PI), 2.f * static_cast(M_PI), static_cast(M_PI) / 9.0f}; + feat.push_back(f); + } + + network.setupColumnInputs(in, feat, "validMask"); + } + //__________________________________________________ template std::vector createNetworkPrediction(TCCDB& ccdb, soa::Join const& collisions, M const& mults, T const& tracks, B const& bcs, const size_t size) @@ -489,6 +595,7 @@ class pidTPCModule network.initModel(pidTPCopts.networkPathLocally.value, pidTPCopts.enableNetworkOptimizations.value, pidTPCopts.networkSetNumThreads.value, strtoul(headers["Valid-From"].c_str(), NULL, 0), strtoul(headers["Valid-Until"].c_str(), NULL, 0)); std::vector dummyInput(network.getNumInputNodes(), 1.); network.evalModel(dummyInput); + setupColumnInputNetwork(); LOGP(info, "Retrieved NN corrections for production tag {}, pass number {}, NN-Version number {}", headers["LPMProductionTag"], headers["RecoPassName"], headers["NN-Version"]); } else { LOG(fatal) << "No valid NN object found matching retrieved Bethe-Bloch parametrisation for pass " << metadata["RecoPassName"] << ". Please ensure that the requested pass has dedicated NN corrections available"; @@ -497,19 +604,14 @@ class pidTPCModule } // Defining some network parameters - int input_dimensions = network.getNumInputNodes(); + const int nFeat = network.getNumFeatures(); int output_dimensions = network.getNumOutputNodes(); - const uint64_t track_prop_size = input_dimensions * size; const uint64_t prediction_size = output_dimensions * size; network_prediction = std::vector(prediction_size * 9); // For each mass hypotheses const float nNclNormalization = response->GetNClNormalization(); float duration_network = 0; - std::vector track_properties(track_prop_size); - uint64_t counter_track_props = 0; - int loop_counter = 0; - // To load the Hadronic rate once for each collision float hadronicRateBegin = 0.; std::vector hadronicRateForCollision(collisions.size(), 0.0f); @@ -530,88 +632,113 @@ class pidTPCModule hadronicRateBegin = 0.0f; } - // Filling a std::vector to be evaluated by the network - // Evaluation on single tracks brings huge overhead: Thus evaluation is done on one large vector static constexpr int NParticleTypes = 9; constexpr int ExpectedInputDimensionsNNV2 = 7; constexpr int ExpectedInputDimensionsNNV3 = 8; constexpr int ExpectedInputDimensionsNNV4 = 9; - constexpr auto NetworkVersionV2 = "2"; - constexpr auto NetworkVersionV3 = "3"; - constexpr auto NetworkVersionV4 = "4"; - for (int j = 0; j < NParticleTypes; j++) { // Loop over particle number for which network correction is used - for (auto const& trk : tracks) { - if (!trk.hasTPC()) { - continue; - } - if (pidTPCopts.skipTPCOnly) { - if (!trk.hasITS() && !trk.hasTRD() && !trk.hasTOF()) { - continue; - } - } - track_properties[counter_track_props] = trk.tpcInnerParam(); - track_properties[counter_track_props + 1] = trk.tgl(); - track_properties[counter_track_props + 2] = trk.signed1Pt(); - track_properties[counter_track_props + 3] = o2::track::pid_constants::sMasses[j]; - track_properties[counter_track_props + 4] = trk.has_collision() ? mults[trk.collisionId()] / 11000. : 1.; - track_properties[counter_track_props + 5] = std::sqrt(nNclNormalization / trk.tpcNClsFound()); - if (input_dimensions == ExpectedInputDimensionsNNV2 && networkVersion == NetworkVersionV2) { - track_properties[counter_track_props + 6] = trk.has_collision() ? collisions.iteratorAt(trk.collisionId()).ft0cOccupancyInTimeRange() / 60000. : 1.; - } - if (input_dimensions == ExpectedInputDimensionsNNV3 && networkVersion == NetworkVersionV3) { - track_properties[counter_track_props + 6] = trk.has_collision() ? collisions.iteratorAt(trk.collisionId()).ft0cOccupancyInTimeRange() / 60000. : 1.; - if (trk.has_collision()) { - if (collsys == CollisionSystemType::kCollSyspp) { - track_properties[counter_track_props + 7] = hadronicRateForCollision[trk.collisionId()] / 1500.; - } else { - track_properties[counter_track_props + 7] = hadronicRateForCollision[trk.collisionId()] / 50.; - } - } else { - // asign Hadronic Rate at beginning of run if track does not belong to a collision - if (collsys == CollisionSystemType::kCollSyspp) { - track_properties[counter_track_props + 7] = hadronicRateBegin / 1500.; - } else { - track_properties[counter_track_props + 7] = hadronicRateBegin / 50.; - } - } + + const float hadronicRateDivisor = (collsys == CollisionSystemType::kCollSyspp) ? 1500.f : 50.f; + + // Per-collision arrays (O(nColl)); gathered per track inside the model via the + // collisionId column, then normalised in-graph. + const int64_t nColl = static_cast(collisions.size()); + std::vector multArray(nColl); + std::vector occArray(nFeat >= ExpectedInputDimensionsNNV2 ? nColl : 0); + { + int64_t c = 0; + for (const auto& col : collisions) { + multArray[c] = static_cast(mults[c]); + if (nFeat >= ExpectedInputDimensionsNNV2) { + occArray[c] = col.ft0cOccupancyInTimeRange(); } + ++c; + } + } - if (input_dimensions == ExpectedInputDimensionsNNV4 && networkVersion == NetworkVersionV4) { - track_properties[counter_track_props + 6] = trk.has_collision() ? collisions.iteratorAt(trk.collisionId()).ft0cOccupancyInTimeRange() / 60000. : 1.; - if (trk.has_collision()) { - if (collsys == CollisionSystemType::kCollSyspp) { - track_properties[counter_track_props + 7] = hadronicRateForCollision[trk.collisionId()] / 1500.; - } else { - track_properties[counter_track_props + 7] = hadronicRateForCollision[trk.collisionId()] / 50.; - } - } else { - // asign Hadronic Rate at beginning of run if track does not belong to a collision - if (collsys == CollisionSystemType::kCollSyspp) { - track_properties[counter_track_props + 7] = hadronicRateBegin / 1500.; - } else { - track_properties[counter_track_props + 7] = hadronicRateBegin / 50.; - } - } - track_properties[counter_track_props + 8] = std::fmod(std::fmod(trk.phi(), 2 * M_PI) + 2 * M_PI, M_PI / 9.0); + // Raw per-track Arrow column buffers (zero-copy; one chunk per DataFrame). + auto arrowTable = tracks.asArrowTable(); + auto chunk0 = [&](const char* name) -> std::shared_ptr { + const int idx = arrowTable->schema()->GetFieldIndex(name); + if (idx < 0) { + LOG(fatal) << "createNetworkPrediction: column '" << name << "' not found in tracks table"; + } + auto col = arrowTable->column(idx); + if (col->num_chunks() != 1) { + LOG(fatal) << "createNetworkPrediction: column '" << name << "' has " << col->num_chunks() + << " chunks; a single chunk per DataFrame is required for zero-copy input"; + } + return col->chunk(0); + }; + const int64_t nTrk = static_cast(tracks.size()); + const float* pTpcInner = std::static_pointer_cast(chunk0("fTPCInnerParam"))->raw_values(); + const float* pTgl = std::static_pointer_cast(chunk0("fTgl"))->raw_values(); + const float* pSigned1Pt = std::static_pointer_cast(chunk0("fSigned1Pt"))->raw_values(); + const int32_t* pCollId = std::static_pointer_cast(chunk0("fIndexCollisions"))->raw_values(); + const uint8_t* pFindable = std::static_pointer_cast(chunk0("fTPCNClsFindable"))->raw_values(); + const int8_t* pFMF = std::static_pointer_cast(chunk0("fTPCNClsFindableMinusFound"))->raw_values(); + const float* pPhi = (nFeat >= ExpectedInputDimensionsNNV4) + ? std::static_pointer_cast(chunk0("fPhi"))->raw_values() + : nullptr; + + // Single boolean mask of the tracks the network runs on; the model Compress'es + // to exactly these rows so the output is compact and the consumer's + // count_tracks indexing is unchanged. Condition matches process()'s counter. + std::vector validMask(nTrk); + { + int64_t t = 0; + for (auto const& trk : tracks) { + bool valid = trk.hasTPC(); + if (valid && pidTPCopts.skipTPCOnly && !trk.hasITS() && !trk.hasTRD() && !trk.hasTOF()) { + valid = false; } - counter_track_props += input_dimensions; + validMask[t++] = valid ? 1 : 0; + } + } + + auto memInfo = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault); + const int64_t one = 1; + float massVal = 0.f; + float nclNormVal = nNclNormalization; + float hrDivisorVal = hadronicRateDivisor; + float hrFallbackVal = hadronicRateBegin / hadronicRateDivisor; + + // Evaluate once per mass hypothesis; only the mass scalar input changes. + for (int j = 0; j < NParticleTypes; j++) { + massVal = o2::track::pid_constants::sMasses[j]; + + std::vector inputTensors; + inputTensors.emplace_back(Ort::Value::CreateTensor(memInfo, const_cast(pTpcInner), nTrk, &nTrk, 1)); + inputTensors.emplace_back(Ort::Value::CreateTensor(memInfo, const_cast(pTgl), nTrk, &nTrk, 1)); + inputTensors.emplace_back(Ort::Value::CreateTensor(memInfo, const_cast(pSigned1Pt), nTrk, &nTrk, 1)); + inputTensors.emplace_back(Ort::Value::CreateTensor(memInfo, &massVal, 1, &one, 1)); + inputTensors.emplace_back(Ort::Value::CreateTensor(memInfo, const_cast(pCollId), nTrk, &nTrk, 1)); + inputTensors.emplace_back(Ort::Value::CreateTensor(memInfo, multArray.data(), nColl, &nColl, 1)); + inputTensors.emplace_back(Ort::Value::CreateTensor(memInfo, &nclNormVal, 1, &one, 1)); + inputTensors.emplace_back(Ort::Value::CreateTensor(memInfo, const_cast(pFindable), nTrk, &nTrk, 1)); + inputTensors.emplace_back(Ort::Value::CreateTensor(memInfo, const_cast(pFMF), nTrk, &nTrk, 1)); + if (nFeat >= ExpectedInputDimensionsNNV2) { + inputTensors.emplace_back(Ort::Value::CreateTensor(memInfo, occArray.data(), nColl, &nColl, 1)); + } + if (nFeat >= ExpectedInputDimensionsNNV3) { + inputTensors.emplace_back(Ort::Value::CreateTensor(memInfo, hadronicRateForCollision.data(), nColl, &nColl, 1)); + inputTensors.emplace_back(Ort::Value::CreateTensor(memInfo, &hrDivisorVal, 1, &one, 1)); + inputTensors.emplace_back(Ort::Value::CreateTensor(memInfo, &hrFallbackVal, 1, &one, 1)); } + if (nFeat >= ExpectedInputDimensionsNNV4) { + inputTensors.emplace_back(Ort::Value::CreateTensor(memInfo, const_cast(pPhi), nTrk, &nTrk, 1)); + } + inputTensors.emplace_back(Ort::Value::CreateTensor(memInfo, reinterpret_cast(validMask.data()), nTrk, &nTrk, 1)); auto start_network_eval = std::chrono::high_resolution_clock::now(); - float* output_network = network.evalModel(track_properties); + float* output_network = network.evalModel(inputTensors); auto stop_network_eval = std::chrono::high_resolution_clock::now(); duration_network += std::chrono::duration>(stop_network_eval - start_network_eval).count(); for (uint64_t k = 0; k < prediction_size; k += output_dimensions) { for (int l = 0; l < output_dimensions; l++) { - network_prediction[k + l + prediction_size * loop_counter] = output_network[k + l]; + network_prediction[k + l + prediction_size * j] = output_network[k + l]; } } - - counter_track_props = 0; - loop_counter += 1; } - track_properties.clear(); - auto stop_network_total = std::chrono::high_resolution_clock::now(); LOG(debug) << "Neural Network for the TPC PID response correction: Time per track (eval ONNX): " << duration_network / (size * 9) << "ns ; Total time (eval ONNX): " << duration_network / 1000000000 << " s"; LOG(debug) << "Neural Network for the TPC PID response correction: Time per track (eval + overhead): " << std::chrono::duration>(stop_network_total - start_network_total).count() / (size * 9) << "ns ; Total time (eval + overhead): " << std::chrono::duration>(stop_network_total - start_network_total).count() / 1000000000 << " s"; diff --git a/Tools/ML/CMakeLists.txt b/Tools/ML/CMakeLists.txt index b95c108584a..4dde4095cb6 100644 --- a/Tools/ML/CMakeLists.txt +++ b/Tools/ML/CMakeLists.txt @@ -11,5 +11,5 @@ o2physics_add_library(MLCore SOURCES model.cxx - PUBLIC_LINK_LIBRARIES O2::Framework O2Physics::AnalysisCore ONNXRuntime::ONNXRuntime + PUBLIC_LINK_LIBRARIES O2::Framework O2Physics::AnalysisCore ONNXRuntime::ONNXRuntime ONNX::onnx_proto ) diff --git a/Tools/ML/model.h b/Tools/ML/model.h index 3be08e72fa9..9bb434ec59f 100644 --- a/Tools/ML/model.h +++ b/Tools/ML/model.h @@ -22,6 +22,7 @@ #include +#include #include #include @@ -29,8 +30,11 @@ #include #include #include +#include +#include #include #include +#include #include #include @@ -132,6 +136,483 @@ class OnnxModel mSession.reset(new Ort::Session{*mEnv, modelPath.c_str(), sessionOptions}); } + // Declaration of a raw graph input fed directly from an Arrow buffer. + struct PreprocInput { + enum class Type { TrackFloat, // per-track float column [N] + TrackInt32, // per-track int32 column [N] + TrackUint8, // per-track uint8 column [N] + TrackInt8, // per-track int8 column [N] + TrackBool, // per-track bool mask [N] + CollisionFloat,// per-collision float array [C] + ScalarFloat }; // single scalar (e.g. mass) [1] + std::string name; + Type type; + }; + + // Preprocessing recipe for one network feature (produces a [N] float tensor + // that feeds column i of the decomposed first layer). + struct PreprocFeature { + enum class Op { + Passthrough, // feature = a (a: TrackFloat) + BroadcastScalar, // feature = Expand(a, shape(shapeRef)) (a: ScalarFloat) + NClsSqrtRecip, // feature = Sqrt(c0 / (float(a) - float(b))) (a,b: Track int cols) + Mod2, // feature = Mod(Mod(a, c0) + c1, c2) (a: TrackFloat) + GatherNormWhere // feature = Where(b<0, fb, Gather(a, b) / c0) (a: CollisionFloat, b: TrackInt32) + }; + Op op; + std::string a; // primary input + std::string b; // secondary input (NCls: 2nd col; Gather: index) + std::string shapeRef; // BroadcastScalar: [N] input to size the Expand + std::string fallbackInput; // GatherNormWhere: scalar input for the b<0 fallback; "" => use c[1] + std::string scaleInput; // NCls: numerator scalar input; Gather: divisor scalar input; "" => use c[0] + std::array c{}; // op constants + }; + + // Rebuild the model from scratch so the network reads its raw Arrow inputs + // directly and performs all preprocessing + the first linear layer inside the + // graph. Each feature column is produced by a small preprocessing subgraph + // (`features`) from the raw inputs (`inputDefs`), then the first linear layer is + // decomposed + // layer0 = X @ W + b = sum_i (feat_i[N,1] @ W_row_i[1,H]) + b + // so no [N, K] interleaving / Concat buffer is ever materialised. The original + // layers 1..N are copied verbatim on top of the decomposed layer-0 output. + // Building a fresh model (rather than augmenting the existing one) is required: + // the Model Editor can only add nodes, so the original layer-0 Gemm would + // otherwise remain and collide on the layer-0 output name. + // If maskInput is non-empty it names a bool [N] input; each feature is then + // Compress'd by it so the matmul runs only on the selected (valid) rows and the + // output is the compact set of selected tracks, in order. + void setupColumnInputs(const std::vector& inputDefs, + const std::vector& features, + const std::string& maskInput = "") + { + const int numFeatures = static_cast(features.size()); + if (numFeatures != mInputShapes[0][1]) { + LOG(fatal) << "setupColumnInputs: expected " << mInputShapes[0][1] << " features, got " << numFeatures; + return; + } + + onnx::ModelProto onnxModel; + { + std::ifstream ifs(modelPath, std::ios::binary); + if (!ifs || !onnxModel.ParseFromIstream(&ifs)) { + LOG(fatal) << "setupColumnInputs: failed to parse ONNX model from " << modelPath; + return; + } + } + const auto& og = onnxModel.graph(); + + int opset = 0; + for (const auto& oi : onnxModel.opset_import()) { + if (oi.domain().empty()) { + opset = static_cast(oi.version()); + } + } + if (opset == 0) { + opset = 13; // Unsqueeze with axes as input requires opset >= 13 + } + + auto findInit = [&](const std::string& name) -> const onnx::TensorProto* { + for (int i = 0; i < og.initializer_size(); ++i) { + if (og.initializer(i).name() == name) { + return &og.initializer(i); + } + } + return nullptr; + }; + auto tensorFloats = [&](const onnx::TensorProto* t, int64_t n) { + std::vector v(n); + if (t->raw_data().size() > 0) { + std::memcpy(v.data(), t->raw_data().data(), n * sizeof(float)); + } else { + for (int64_t i = 0; i < n; ++i) { + v[i] = t->float_data(i); + } + } + return v; + }; + + // --- locate the first linear layer (layer 0) --- + const onnx::NodeProto* first = nullptr; + for (int i = 0; i < og.node_size(); ++i) { + const auto& n = og.node(i); + if (n.op_type() == "Gemm" || n.op_type() == "MatMul") { + first = &n; + break; + } + } + if (!first) { + LOG(fatal) << "setupColumnInputs: no Gemm/MatMul layer found in model"; + return; + } + const std::string layer0Out = first->output(0); // pre-activation output we must reproduce + + const onnx::TensorProto* wT = findInit(first->input(1)); + if (!wT || wT->dims_size() != 2) { + LOG(fatal) << "setupColumnInputs: first-layer weight initializer not found or not 2D"; + return; + } + bool transB = false; + if (first->op_type() == "Gemm") { + for (int i = 0; i < first->attribute_size(); ++i) { + if (first->attribute(i).name() == "transB") { + transB = (first->attribute(i).i() != 0); + } + } + } + const int K = transB ? static_cast(wT->dims(1)) : static_cast(wT->dims(0)); + const int H = transB ? static_cast(wT->dims(0)) : static_cast(wT->dims(1)); + if (K != numFeatures) { + LOG(fatal) << "setupColumnInputs: first-layer K=" << K << " != numFeatures=" << numFeatures; + return; + } + const std::vector wData = tensorFloats(wT, static_cast(K) * H); + + std::vector bData; + if (first->op_type() == "Gemm" && first->input_size() >= 3 && !first->input(2).empty()) { + if (const onnx::TensorProto* bT = findInit(first->input(2))) { + bData = tensorFloats(bT, H); + } + } + + // --- build the new graph --- + Ort::AllocatorWithDefaultOptions allocator; + Ort::Graph graph; + + auto addFloatInit = [&](const std::string& name, const std::vector& data, + const std::vector& shape) { + auto val = Ort::Value::CreateTensor(allocator, shape.data(), shape.size()); + std::memcpy(val.GetTensorMutableData(), data.data(), data.size() * sizeof(float)); + graph.AddInitializer(name, val, false); + }; + + // axes = {1} for Unsqueeze (opset >= 13 takes axes as an input) + { + std::vector shape = {1}; + auto val = Ort::Value::CreateTensor(allocator, shape.data(), shape.size()); + val.GetTensorMutableData()[0] = 1; + graph.AddInitializer("_col_axes", val, false); + } + + auto addScalarF = [&](const std::string& name, float v) { + const std::vector shape = {1}; + auto val = Ort::Value::CreateTensor(allocator, shape.data(), shape.size()); + val.GetTensorMutableData()[0] = v; + graph.AddInitializer(name, val, false); + }; + addScalarF("_decZeroF", 0.f); // unused placeholder kept for symmetry + { + const std::vector shape = {1}; + auto val = Ort::Value::CreateTensor(allocator, shape.data(), shape.size()); + val.GetTensorMutableData()[0] = 0; + graph.AddInitializer("_decZeroI32", val, false); + } + + // --- raw graph inputs (wrapped zero-copy from Arrow at inference time) --- + std::vector inputs; + inputs.reserve(inputDefs.size()); + std::vector rawInputNames; + rawInputNames.reserve(inputDefs.size()); + for (const auto& pin : inputDefs) { + ONNXTensorElementDataType et = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; + std::vector dims = {-1}; + std::vector sym = {"N"}; + switch (pin.type) { + case PreprocInput::Type::TrackFloat: et = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; break; + case PreprocInput::Type::TrackInt32: et = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32; break; + case PreprocInput::Type::TrackUint8: et = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8; break; + case PreprocInput::Type::TrackInt8: et = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8; break; + case PreprocInput::Type::TrackBool: et = ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL; break; + case PreprocInput::Type::CollisionFloat: et = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; sym = {"C"}; break; + case PreprocInput::Type::ScalarFloat: et = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; dims = {1}; sym = {""}; break; + } + Ort::TensorTypeAndShapeInfo tInfo(et, dims, &sym); + auto typeInfo = Ort::TypeInfo::CreateTensorInfo(tInfo.GetConst()); + inputs.emplace_back(pin.name, typeInfo.GetConst()); + rawInputNames.push_back(pin.name); + } + + // --- preprocessing subgraph: produce one [N] float feature tensor per column --- + const int64_t toFloat = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; + const int64_t fmodFlag = 1; + std::vector featNames(K); + for (int i = 0; i < K; ++i) { + const auto& f = features[i]; + const std::string p = "_pp" + std::to_string(i); + switch (f.op) { + case PreprocFeature::Op::Passthrough: { + featNames[i] = f.a; // the raw input is already the [N] feature + break; + } + case PreprocFeature::Op::BroadcastScalar: { + Ort::Node shapeNode("Shape", "", p + "_shape", {f.shapeRef}, {p + "_shp"}); + graph.AddNode(shapeNode); + Ort::Node expandNode("Expand", "", p + "_expand", {f.a, p + "_shp"}, {p}); + graph.AddNode(expandNode); + featNames[i] = p; + break; + } + case PreprocFeature::Op::NClsSqrtRecip: { + std::vector ca; + ca.emplace_back("to", &toFloat, 1, ORT_OP_ATTR_INT); + Ort::Node castA("Cast", "", p + "_castA", {f.a}, {p + "_a"}, ca); + graph.AddNode(castA); + std::vector cb; + cb.emplace_back("to", &toFloat, 1, ORT_OP_ATTR_INT); + Ort::Node castB("Cast", "", p + "_castB", {f.b}, {p + "_b"}, cb); + graph.AddNode(castB); + Ort::Node subNode("Sub", "", p + "_sub", {p + "_a", p + "_b"}, {p + "_ncl"}); + graph.AddNode(subNode); + std::string numer = f.scaleInput; + if (numer.empty()) { + addScalarF(p + "_c0", f.c[0]); + numer = p + "_c0"; + } + Ort::Node divNode("Div", "", p + "_div", {numer, p + "_ncl"}, {p + "_recip"}); + graph.AddNode(divNode); + Ort::Node sqrtNode("Sqrt", "", p + "_sqrt", {p + "_recip"}, {p}); + graph.AddNode(sqrtNode); + featNames[i] = p; + break; + } + case PreprocFeature::Op::Mod2: { + addScalarF(p + "_c0", f.c[0]); + addScalarF(p + "_c1", f.c[1]); + addScalarF(p + "_c2", f.c[2]); + std::vector m1a; + m1a.emplace_back("fmod", &fmodFlag, 1, ORT_OP_ATTR_INT); + Ort::Node mod1("Mod", "", p + "_mod1", {f.a, p + "_c0"}, {p + "_m1"}, m1a); + graph.AddNode(mod1); + Ort::Node addNode("Add", "", p + "_add", {p + "_m1", p + "_c1"}, {p + "_m1a"}); + graph.AddNode(addNode); + std::vector m2a; + m2a.emplace_back("fmod", &fmodFlag, 1, ORT_OP_ATTR_INT); + Ort::Node mod2("Mod", "", p + "_mod2", {p + "_m1a", p + "_c2"}, {p}, m2a); + graph.AddNode(mod2); + featNames[i] = p; + break; + } + case PreprocFeature::Op::GatherNormWhere: { + // Gather tolerates the -1 ambiguous index (negative indexing); the Where + // below discards that value in favour of the fallback, so no Clip needed. + Ort::Node gatherNode("Gather", "", p + "_gather", {f.a, f.b}, {p + "_g"}); + graph.AddNode(gatherNode); + std::string denom = f.scaleInput; + if (denom.empty()) { + addScalarF(p + "_c0", f.c[0]); + denom = p + "_c0"; + } + Ort::Node divNode("Div", "", p + "_div", {p + "_g", denom}, {p + "_gn"}); + graph.AddNode(divNode); + Ort::Node lessNode("Less", "", p + "_less", {f.b, "_decZeroI32"}, {p + "_amb"}); + graph.AddNode(lessNode); + std::string fbName; + if (!f.fallbackInput.empty()) { + fbName = f.fallbackInput; // runtime scalar input + } else { + addScalarF(p + "_fb", f.c[1]); + fbName = p + "_fb"; + } + Ort::Node whereNode("Where", "", p + "_where", {p + "_amb", fbName, p + "_gn"}, {p}); + graph.AddNode(whereNode); + featNames[i] = p; + break; + } + } + } + + // --- decompose the first linear layer over the produced feature tensors --- + // feat_i [N] -> Unsqueeze [N,1] -> MatMul [N,1]x[1,H] -> partial_i [N,H] + for (int i = 0; i < K; ++i) { + std::vector row(H); + for (int j = 0; j < H; ++j) { + row[j] = transB ? wData[static_cast(j) * K + i] : wData[static_cast(i) * H + j]; + } + addFloatInit("_w_row_" + std::to_string(i), row, {1, static_cast(H)}); + + // Optionally drop the filtered-out rows before the matmul (compact output). + std::string decompIn = featNames[i]; + if (!maskInput.empty()) { + const int64_t axis0 = 0; + std::vector ca; + ca.emplace_back("axis", &axis0, 1, ORT_OP_ATTR_INT); + decompIn = "_cf_" + std::to_string(i); + Ort::Node compressNode("Compress", "", "_compress_" + std::to_string(i), + {featNames[i], maskInput}, {decompIn}, ca); + graph.AddNode(compressNode); + } + + const std::string unsq = "_col_unsq_" + std::to_string(i); + Ort::Node unsqNode("Unsqueeze", "", "_unsqueeze_" + std::to_string(i), + {decompIn, "_col_axes"}, {unsq}); + graph.AddNode(unsqNode); + Ort::Node matmulNode("MatMul", "", "_matmul_" + std::to_string(i), + {unsq, "_w_row_" + std::to_string(i)}, {"_partial_" + std::to_string(i)}); + graph.AddNode(matmulNode); + } + + // sum the partials; the final op must produce layer0Out + std::string acc = "_partial_0"; + for (int i = 1; i < K; ++i) { + const bool last = (i == K - 1); + const std::string out = (last && bData.empty()) ? layer0Out : ("_sum_" + std::to_string(i)); + Ort::Node addNode("Add", "", "_add_" + std::to_string(i), + {acc, "_partial_" + std::to_string(i)}, {out}); + graph.AddNode(addNode); + acc = out; + } + if (!bData.empty()) { + addFloatInit("_dec_bias", bData, {static_cast(H)}); + Ort::Node biasNode("Add", "", "_add_bias", {acc, "_dec_bias"}, {layer0Out}); + graph.AddNode(biasNode); + } else if (K == 1) { + Ort::Node idNode("Identity", "", "_id_layer0", {"_partial_0"}, {layer0Out}); + graph.AddNode(idNode); + } + + // --- copy original layers 1..N: every node reachable backwards from the graph + // outputs, except the replaced first layer (and any node, e.g. an input + // Cast, that fed only into it and is now dead) --- + std::set live; + for (int i = 0; i < og.output_size(); ++i) { + live.insert(og.output(i).name()); + } + std::vector keep(og.node_size(), false); + bool changed = true; + while (changed) { + changed = false; + for (int i = 0; i < og.node_size(); ++i) { + if (keep[i] || &og.node(i) == first) { + continue; + } + const auto& n = og.node(i); + bool produces = false; + for (int o = 0; o < n.output_size(); ++o) { + if (live.count(n.output(o))) { + produces = true; + break; + } + } + if (!produces) { + continue; + } + keep[i] = true; + changed = true; + for (int in = 0; in < n.input_size(); ++in) { + live.insert(n.input(in)); + } + } + } + + int copiedNodes = 0; + std::set neededInits; + for (int i = 0; i < og.node_size(); ++i) { + if (!keep[i]) { + continue; + } + const auto& n = og.node(i); + const std::vector ins(n.input().begin(), n.input().end()); + const std::vector outs(n.output().begin(), n.output().end()); + std::vector attrs; + for (int a = 0; a < n.attribute_size(); ++a) { + const auto& at = n.attribute(a); + switch (at.type()) { + case onnx::AttributeProto::FLOAT: { + const float v = at.f(); + attrs.emplace_back(at.name().c_str(), &v, 1, ORT_OP_ATTR_FLOAT); + break; + } + case onnx::AttributeProto::INT: { + const int64_t v = at.i(); + attrs.emplace_back(at.name().c_str(), &v, 1, ORT_OP_ATTR_INT); + break; + } + case onnx::AttributeProto::FLOATS: { + const std::vector v(at.floats().begin(), at.floats().end()); + attrs.emplace_back(at.name().c_str(), v.data(), static_cast(v.size()), ORT_OP_ATTR_FLOATS); + break; + } + case onnx::AttributeProto::INTS: { + const std::vector v(at.ints().begin(), at.ints().end()); + attrs.emplace_back(at.name().c_str(), v.data(), static_cast(v.size()), ORT_OP_ATTR_INTS); + break; + } + case onnx::AttributeProto::STRING: { + const std::string& s = at.s(); + attrs.emplace_back(at.name().c_str(), s.data(), static_cast(s.size()), ORT_OP_ATTR_STRING); + break; + } + default: + LOG(fatal) << "setupColumnInputs: unhandled attribute type " << at.type() << " on node " << n.name(); + return; + } + } + Ort::Node node(n.op_type(), n.domain(), + n.name().empty() ? ("_copy_" + std::to_string(i)) : n.name(), ins, outs, attrs); + graph.AddNode(node); + ++copiedNodes; + for (const auto& in : ins) { + if (findInit(in)) { + neededInits.insert(in); + } + } + } + + for (const auto& name : neededInits) { + const onnx::TensorProto* t = findInit(name); + std::vector shape(t->dims().begin(), t->dims().end()); + int64_t n = 1; + for (const auto d : shape) { + n *= d; + } + if (t->data_type() != onnx::TensorProto::FLOAT) { + LOG(fatal) << "setupColumnInputs: unsupported initializer dtype " << t->data_type() << " for " << name; + return; + } + addFloatInit(name, tensorFloats(t, n), shape); + } + + // graph outputs (copy name + tensor type/shape from the original model) + std::vector outputs; + for (int i = 0; i < og.output_size(); ++i) { + const auto& vp = og.output(i); + const auto& tt = vp.type().tensor_type(); + std::vector dims; + std::vector sym; + for (int d = 0; d < tt.shape().dim_size(); ++d) { + const auto& dd = tt.shape().dim(d); + if (dd.has_dim_value()) { + dims.push_back(dd.dim_value()); + sym.emplace_back(""); + } else { + dims.push_back(-1); + sym.push_back(dd.dim_param().empty() ? ("d" + std::to_string(d)) : dd.dim_param()); + } + } + Ort::TensorTypeAndShapeInfo tInfo(static_cast(tt.elem_type()), dims, &sym); + auto typeInfo = Ort::TypeInfo::CreateTensorInfo(tInfo.GetConst()); + outputs.emplace_back(vp.name(), typeInfo.GetConst()); + } + + graph.SetInputs(inputs); + graph.SetOutputs(outputs); + + // Declare the ONNX domain plus the ORT contrib domain: with graph + // optimizations enabled ORT fuses Gemm+activation into com.microsoft FusedGemm + // ops, which need an opset import for that domain in a Model-Editor-built model. + Ort::Model model({{std::string(), opset}, {std::string("com.microsoft"), 1}}); + model.AddGraph(graph); + mSession = std::make_shared(*mEnv, model, sessionOptions); + + mInputNames = rawInputNames; + mInputShapes.assign(rawInputNames.size(), std::vector{-1}); + mNumFeatures = K; + + LOG(info) << "setupColumnInputs: rebuilt model with " << rawInputNames.size() << " raw inputs -> " + << K << " preprocessed features, layer 0 decomposed (H=" << H << "), " + << copiedNodes << " downstream nodes copied"; + } + // Getters & Setters Ort::SessionOptions* getSessionOptions() { return &sessionOptions; } // For optimizations in post std::shared_ptr getSession() @@ -139,6 +620,9 @@ class OnnxModel return mSession; } int getNumInputNodes() const { return mInputShapes[0][1]; } + bool hasColumnInputs() const { return mInputShapes.size() > 1 || (mInputShapes.size() == 1 && mInputShapes[0].size() == 1); } + int getNumColumns() const { return static_cast(mInputNames.size()); } + int getNumFeatures() const { return mNumFeatures; } std::vector> getInputShapes() const { return mInputShapes; } int getNumOutputNodes() const { return mOutputShapes[0][1]; } uint64_t getValidityFrom() const { return validFrom; } @@ -154,6 +638,7 @@ class OnnxModel // Input & Output specifications of the loaded network std::vector mInputNames; std::vector> mInputShapes; + int mNumFeatures = 0; std::vector mOutputNames; std::vector> mOutputShapes; diff --git a/dependencies/O2PhysicsDependencies.cmake b/dependencies/O2PhysicsDependencies.cmake index d1a2a6280ac..807755aa3b8 100644 --- a/dependencies/O2PhysicsDependencies.cmake +++ b/dependencies/O2PhysicsDependencies.cmake @@ -25,5 +25,6 @@ find_package(fjcontrib) set_package_properties(fjcontrib PROPERTIES TYPE REQUIRED) find_package(ONNXRuntime) +find_package(ONNX) feature_summary(WHAT ALL FATAL_ON_MISSING_REQUIRED_PACKAGES) From 706de99cc52062ac8c4f8f6d2b24c740036502c0 Mon Sep 17 00:00:00 2001 From: ALICE Action Bot Date: Tue, 2 Jun 2026 12:20:04 +0000 Subject: [PATCH 2/2] Please consider the following formatting changes --- Tools/ML/model.h | 45 +++++++++++++++++++++++++++++++-------------- 1 file changed, 31 insertions(+), 14 deletions(-) diff --git a/Tools/ML/model.h b/Tools/ML/model.h index 9bb434ec59f..df47f10b6a3 100644 --- a/Tools/ML/model.h +++ b/Tools/ML/model.h @@ -138,13 +138,13 @@ class OnnxModel // Declaration of a raw graph input fed directly from an Arrow buffer. struct PreprocInput { - enum class Type { TrackFloat, // per-track float column [N] - TrackInt32, // per-track int32 column [N] - TrackUint8, // per-track uint8 column [N] - TrackInt8, // per-track int8 column [N] - TrackBool, // per-track bool mask [N] - CollisionFloat,// per-collision float array [C] - ScalarFloat }; // single scalar (e.g. mass) [1] + enum class Type { TrackFloat, // per-track float column [N] + TrackInt32, // per-track int32 column [N] + TrackUint8, // per-track uint8 column [N] + TrackInt8, // per-track int8 column [N] + TrackBool, // per-track bool mask [N] + CollisionFloat, // per-collision float array [C] + ScalarFloat }; // single scalar (e.g. mass) [1] std::string name; Type type; }; @@ -318,13 +318,30 @@ class OnnxModel std::vector dims = {-1}; std::vector sym = {"N"}; switch (pin.type) { - case PreprocInput::Type::TrackFloat: et = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; break; - case PreprocInput::Type::TrackInt32: et = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32; break; - case PreprocInput::Type::TrackUint8: et = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8; break; - case PreprocInput::Type::TrackInt8: et = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8; break; - case PreprocInput::Type::TrackBool: et = ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL; break; - case PreprocInput::Type::CollisionFloat: et = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; sym = {"C"}; break; - case PreprocInput::Type::ScalarFloat: et = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; dims = {1}; sym = {""}; break; + case PreprocInput::Type::TrackFloat: + et = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; + break; + case PreprocInput::Type::TrackInt32: + et = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32; + break; + case PreprocInput::Type::TrackUint8: + et = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8; + break; + case PreprocInput::Type::TrackInt8: + et = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8; + break; + case PreprocInput::Type::TrackBool: + et = ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL; + break; + case PreprocInput::Type::CollisionFloat: + et = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; + sym = {"C"}; + break; + case PreprocInput::Type::ScalarFloat: + et = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; + dims = {1}; + sym = {""}; + break; } Ort::TensorTypeAndShapeInfo tInfo(et, dims, &sym); auto typeInfo = Ort::TypeInfo::CreateTensorInfo(tInfo.GetConst());