Commit 1acee842 authored by Wouter Deconinck's avatar Wouter Deconinck
Browse files

Proof of concept algorithm that uses tensorflow-lite

parent aae2c9f6
......@@ -22,6 +22,15 @@ if(ENABLE_CLANG_TIDY)
endif()
endif()
find_library(tflite_library tensorflow-lite /opt/local/lib REQUIRED)
find_file(tflite_interpreter tensorflow/lite/interpreter.h /opt/local/include REQUIRED)
get_filename_component(tflite ${tflite_library} NAME)
get_filename_component(tflite_LIBRARY_DIR ${tflite_library} DIRECTORY)
link_directories(${tflite_LIBRARY_DIR})
get_filename_component(tflite_INCLUDE_DIR ${tflite_interpreter} DIRECTORY)
include_directories(${tflite_INCLUDE_DIR}/../..)
message(STATUS "tensorflow-lite: ${tflite} ${tflite_LIBRARY_DIR} ${tflite_INCLUDE_DIR}")
find_package(EICD REQUIRED)
find_package(NPDet REQUIRED)
......
......@@ -15,6 +15,7 @@ gaudi_add_module(JugRecoPlugins
src/components/CalorimeterHitsEtaPhiProjector.cpp
src/components/CalorimeterHitsMerger.cpp
src/components/CalorimeterIslandCluster.cpp
src/components/ClusterIdentification.cpp
src/components/EnergyPositionClusterMerger.cpp
src/components/ClusterRecoCoG.cpp
src/components/ParticleCollector.cpp
......@@ -31,6 +32,7 @@ gaudi_add_module(JugRecoPlugins
src/components/InclusiveKinematicseSigma.cpp
src/components/PhotoMultiplierReco.cpp
LINK
${tflite}
Gaudi::GaudiAlgLib Gaudi::GaudiKernel
JugBase
ROOT::Core ROOT::RIO ROOT::Tree
......
#include <algorithm>
#include "Gaudi/Property.h"
#include "GaudiAlg/GaudiAlgorithm.h"
#include "GaudiAlg/GaudiTool.h"
#include "GaudiAlg/Transformer.h"
#include "GaudiKernel/PhysicalConstants.h"
#include "GaudiKernel/RndmGenerators.h"
#include "GaudiKernel/ToolHandle.h"
#include "JugBase/DataHandle.h"
#include "JugBase/IGeoSvc.h"
#include "JugBase/UniqueID.h"
// Tensorflow headers
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/optional_debug_tools.h"
// Event Model related classes
#include "eicd/ClusterCollection.h"
using namespace Gaudi::Units;
namespace Jug::Reco {
/** Simple cluster identification algorithm using ML.
*
* \ingroup reco
*/
class ClusterIdentification : public GaudiAlgorithm, AlgorithmIDMixin<> {
public:
DataHandle<eic::ClusterCollection> m_inputClusterCollection{"inputClusterCollection", Gaudi::DataHandle::Reader, this};
Gaudi::Property<std::string> m_modelTFLiteFile{this, "modelTFLiteFile", ""};
// interpreter
std::unique_ptr<tflite::Interpreter> m_interpreter;
ClusterIdentification(const std::string& name, ISvcLocator* svcLoc)
: GaudiAlgorithm(name, svcLoc)
, AlgorithmIDMixin<>(name, info()) {
declareProperty("inputClusterCollection", m_inputClusterCollection, "");
declareProperty("modelTFLiteFile", m_modelTFLiteFile, "");
}
StatusCode initialize() override
{
if (GaudiAlgorithm::initialize().isFailure()) {
return StatusCode::FAILURE;
}
// load model from file
std::unique_ptr<tflite::FlatBufferModel> model =
tflite::FlatBufferModel::BuildFromFile(m_modelTFLiteFile.value().data());
// build interpreter from model
tflite::ops::builtin::BuiltinOpResolver resolver;
tflite::InterpreterBuilder builder(*model, resolver);
builder(&m_interpreter);
// allocate tensors for interpreter
m_interpreter->AllocateTensors();
// debug
printf("=== Pre-invoke Interpreter State ===\n");
tflite::PrintInterpreterState(m_interpreter.get());
return StatusCode::SUCCESS;
}
StatusCode execute() override
{
// input collections
const auto& clusters = *m_inputClusterCollection.get();
// fill input tensors
double* input0 = m_interpreter->typed_input_tensor<double>(0);
// run inference
if (m_interpreter->Invoke() != kTfLiteOk) return StatusCode::FAILURE;
// debug
printf("\n\n=== Post-invoke Interpreter State ===\n");
tflite::PrintInterpreterState(m_interpreter.get());
// read output tensors
double* output = m_interpreter->typed_output_tensor<double>(0);
return StatusCode::SUCCESS;
}
};
DECLARE_COMPONENT(ClusterIdentification)
} // namespace Jug::Reco
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment