From 26b2146873e436ad2ab7c04efbd6f31c222d3944 Mon Sep 17 00:00:00 2001
From: Dmitry Kalinkin <dmitry.kalinkin@gmail.com>
Date: Thu, 17 Oct 2024 23:42:29 -0400
Subject: [PATCH] Add ML training for EcalEndcapNClusterParticleIDs

---
 .gitlab-ci.yml                       |   4 +-
 Snakefile                            |   1 +
 benchmarks/calo_pid/Snakefile        |  97 +++++++
 benchmarks/calo_pid/calo_pid.org     | 364 +++++++++++++++++++++++++++
 benchmarks/calo_pid/config.yml       |  45 ++++
 benchmarks/calo_pid/requirements.txt |   6 +
 6 files changed, 516 insertions(+), 1 deletion(-)
 create mode 100644 benchmarks/calo_pid/Snakefile
 create mode 100644 benchmarks/calo_pid/calo_pid.org
 create mode 100644 benchmarks/calo_pid/config.yml
 create mode 100644 benchmarks/calo_pid/requirements.txt

diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index e313531b..eb6563ff 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -127,6 +127,7 @@ get_data:
 include: 
   - local: 'benchmarks/backgrounds/config.yml'
   - local: 'benchmarks/backwards_ecal/config.yml'
+  - local: 'benchmarks/calo_pid/config.yml'
   - local: 'benchmarks/ecal_gaps/config.yml'
   - local: 'benchmarks/tracking_detectors/config.yml'
   - local: 'benchmarks/tracking_performances/config.yml'
@@ -159,8 +160,9 @@ deploy_results:
     - "collect_results:backwards_ecal"
     - "collect_results:barrel_ecal"
     - "collect_results:barrel_hcal"
-    - "collect_results:lfhcal"
+    - "collect_results:calo_pid"
     - "collect_results:ecal_gaps"
+    - "collect_results:lfhcal"
     - "collect_results:material_scan"
     - "collect_results:pid"
     - "collect_results:tracking_performance"
diff --git a/Snakefile b/Snakefile
index 1301c922..14e28f0e 100644
--- a/Snakefile
+++ b/Snakefile
@@ -33,6 +33,7 @@ def find_epic_libraries():
 include: "benchmarks/backgrounds/Snakefile"
 include: "benchmarks/backwards_ecal/Snakefile"
 include: "benchmarks/barrel_ecal/Snakefile"
+include: "benchmarks/calo_pid/Snakefile"
 include: "benchmarks/ecal_gaps/Snakefile"
 include: "benchmarks/material_scan/Snakefile"
 include: "benchmarks/tracking_performances/Snakefile"
diff --git a/benchmarks/calo_pid/Snakefile b/benchmarks/calo_pid/Snakefile
new file mode 100644
index 00000000..a93cfc40
--- /dev/null
+++ b/benchmarks/calo_pid/Snakefile
@@ -0,0 +1,97 @@
+def format_energy_for_dd4hep(s):
+    return s.rstrip("kMGeV") + "*" + s.lstrip("0123456789")
+
+rule calo_pid_sim:
+    input:
+        warmup="warmup/{DETECTOR_CONFIG}.edm4hep.root",
+        geometry_lib=find_epic_libraries(),
+    output:
+        "sim_output/calo_pid/{DETECTOR_CONFIG}/{PARTICLE}/{ENERGY_MIN}to{ENERGY_MAX}/{THETA_MIN}to{THETA_MAX}deg/{PARTICLE}_{ENERGY}_{THETA_MIN}to{THETA_MAX}deg.{INDEX}.edm4hep.root",
+    log:
+        "sim_output/calo_pid/{DETECTOR_CONFIG}/{PARTICLE}/{ENERGY_MIN}to{ENERGY_MAX}/{THETA_MIN}to{THETA_MAX}deg/{PARTICLE}_{ENERGY}_{THETA_MIN}to{THETA_MAX}deg.{INDEX}.edm4hep.root.log",
+    wildcard_constraints:
+        PARTICLE="(e-|pi-)",
+        ENERGY_MIN="[0-9]+[kMG]eV",
+        ENERGY_MAX="[0-9]+[kMG]eV",
+        THETA_MIN="[0-9]+",
+        THETA_MAX="[0-9]+",
+        INDEX="\d{4}",
+    params:
+        N_EVENTS=1000,
+        SEED=lambda wildcards: "1" + wildcards.INDEX,
+        DETECTOR_PATH=os.environ["DETECTOR_PATH"],
+        DETECTOR_CONFIG=lambda wildcards: wildcards.DETECTOR_CONFIG,
+        ENERGY_MIN=lambda wildcards: format_energy_for_dd4hep(wildcards.ENERGY_MIN),
+        ENERGY_MAX=lambda wildcards: format_energy_for_dd4hep(wildcards.ENERGY_MAX),
+        DD4HEP_HASH=get_spack_package_hash("dd4hep"),
+        NPSIM_HASH=get_spack_package_hash("npsim"),
+    cache: True
+    shell:
+        """
+set -m # monitor mode to prevent lingering processes
+exec ddsim \
+  --runType batch \
+  --enableGun \
+  --gun.momentumMin "{params.ENERGY_MIN}" \
+  --gun.momentumMax "{params.ENERGY_MAX}" \
+  --gun.thetaMin "{wildcards.THETA_MIN}*deg" \
+  --gun.thetaMax "{wildcards.THETA_MAX}*deg" \
+  --gun.particle {wildcards.PARTICLE} \
+  --gun.distribution eta \
+  --random.seed {params.SEED} \
+  --filter.tracker edep0 \
+  -v WARNING \
+  --numberOfEvents {params.N_EVENTS} \
+  --compactFile {params.DETECTOR_PATH}/{params.DETECTOR_CONFIG}.xml \
+  --outputFile {output}
+"""
+
+
+rule calo_pid_recon:
+    input:
+        "sim_output/calo_pid/{DETECTOR_CONFIG}/{PARTICLE}/{ENERGY}/{PHASE_SPACE}/{PARTICLE}_{ENERGY}_{PHASE_SPACE}.{INDEX}.edm4hep.root",
+    output:
+        "sim_output/calo_pid/{DETECTOR_CONFIG}/{PARTICLE}/{ENERGY}/{PHASE_SPACE}/{PARTICLE}_{ENERGY}_{PHASE_SPACE}.{INDEX}.eicrecon.tree.edm4eic.root",
+    log:
+        "sim_output/calo_pid/{DETECTOR_CONFIG}/{PARTICLE}/{ENERGY}/{PHASE_SPACE}/{PARTICLE}_{ENERGY}_{PHASE_SPACE}.{INDEX}.eicrecon.tree.edm4eic.root.log",
+    wildcard_constraints:
+        INDEX="\d{4}",
+    shell: """
+set -m # monitor mode to prevent lingering processes
+exec env DETECTOR_CONFIG={wildcards.DETECTOR_CONFIG} \
+  eicrecon {input} -Ppodio:output_file={output} \
+  -Ppodio:output_collections=MCParticles,EcalEndcapNRecHits,EcalEndcapNClusters,EcalEndcapNParticleIDInput_features,EcalEndcapNParticleIDTarget
+"""
+
+
+rule calo_pid:
+    input:
+        electrons=expand(
+            "sim_output/calo_pid/{{DETECTOR_CONFIG}}/{PARTICLE}/{ENERGY}/{PHASE_SPACE}/{PARTICLE}_{ENERGY}_{PHASE_SPACE}.{INDEX:04d}.eicrecon.tree.edm4eic.root",
+            PARTICLE=["e-"],
+            ENERGY=["100MeVto20GeV"],
+            PHASE_SPACE=["130to177deg"],
+            INDEX=range(100),
+        ),
+        pions=expand(
+            "sim_output/calo_pid/{{DETECTOR_CONFIG}}/{PARTICLE}/{ENERGY}/{PHASE_SPACE}/{PARTICLE}_{ENERGY}_{PHASE_SPACE}.{INDEX:04d}.eicrecon.tree.edm4eic.root",
+            PARTICLE=["pi-"],
+            ENERGY=["100MeVto20GeV"],
+            PHASE_SPACE=["130to177deg"],
+            INDEX=range(100),
+        ),
+        matplotlibrc=".matplotlibrc",
+        script="benchmarks/calo_pid/calo_pid.py",
+    output:
+        directory("results/{DETECTOR_CONFIG}/calo_pid")
+    shell:
+        """
+env \
+MATPLOTLIBRC={input.matplotlibrc} \
+DETECTOR_CONFIG={wildcards.DETECTOR_CONFIG} \
+PLOT_TITLE={wildcards.DETECTOR_CONFIG} \
+INPUT_ELECTRONS="{input.electrons}" \
+INPUT_PIONS="{input.pions}" \
+OUTPUT_DIR={output} \
+python {input.script}
+"""
diff --git a/benchmarks/calo_pid/calo_pid.org b/benchmarks/calo_pid/calo_pid.org
new file mode 100644
index 00000000..a46aec02
--- /dev/null
+++ b/benchmarks/calo_pid/calo_pid.org
@@ -0,0 +1,364 @@
+#+PROPERTY: header-args:jupyter-python :session /jpy:localhost#8888:benchmark :async yes :results drawer :exports both
+
+#+TITLE: ePIC EEEMCal calorimetry PID
+#+AUTHOR: Dmitry Kalinkin
+#+OPTIONS: d:t
+
+#+LATEX_CLASS_OPTIONS: [9pt,letter]
+#+BIND: org-latex-image-default-width ""
+#+BIND: org-latex-image-default-option "scale=0.3"
+#+BIND: org-latex-images-centered nil
+#+BIND: org-latex-minted-options (("breaklines") ("bgcolor" "black!5") ("frame" "single"))
+#+LATEX_HEADER: \usepackage[margin=1in]{geometry}
+#+LATEX_HEADER: \setlength{\parindent}{0pt}
+#+LATEX: \sloppy
+
+#+begin_src jupyter-python :results silent
+import os
+from pathlib import Path
+
+import awkward as ak
+import boost_histogram as bh
+import numpy as np
+import vector
+import uproot
+from sklearn.metrics import roc_curve
+
+vector.register_awkward()
+#+end_src   
+
+* Parameters
+
+#+begin_src jupyter-python :results silent
+DETECTOR_CONFIG=os.environ.get("DETECTOR_CONFIG")
+PLOT_TITLE=os.environ.get("PLOT_TITLE")
+INPUT_PIONS=os.environ.get("INPUT_PIONS", "").split(" ")
+INPUT_ELECTRONS=os.environ.get("INPUT_ELECTRONS", "").split(" ")
+
+output_dir=Path(os.environ.get("OUTPUT_DIR", "./"))
+output_dir.mkdir(parents=True, exist_ok=True)
+#+end_src
+
+* Plotting setup
+                
+#+begin_src jupyter-python :results silent
+import matplotlib as mpl
+import matplotlib.pyplot as plt
+       
+def setup_presentation_style():
+    mpl.rcParams.update(mpl.rcParamsDefault)
+    plt.style.use('ggplot')
+    plt.rcParams.update({
+        'axes.labelsize': 8,
+        'axes.titlesize': 9,
+        'figure.titlesize': 9,
+        'figure.figsize': (4, 3),
+        'legend.fontsize': 7,
+        'xtick.labelsize': 8,
+        'ytick.labelsize': 8,
+        'pgf.rcfonts': False,
+    })
+
+setup_presentation_style()
+#+end_src       
+
+* Analysis
+
+#+begin_src jupyter-python
+def filter_pointing(events):
+    part_momentum = ak.zip({
+        "m": events["MCParticles.mass"],
+        "px": events["MCParticles.momentum.x"],
+        "py": events["MCParticles.momentum.y"],
+        "pz": events["MCParticles.momentum.z"],
+    }, with_name="Momentum4D")
+    cond = (part_momentum.eta[:,0] > -3.5) & (part_momentum.eta[:,0] < -2.)
+    return events[cond]
+
+e = filter_pointing(uproot.concatenate({filename: "events" for filename in INPUT_ELECTRONS}, filter_name=["MCParticles.*", "*EcalEndcapN*"]))
+pi = filter_pointing(uproot.concatenate({filename: "events" for filename in INPUT_PIONS}, filter_name=["MCParticles.*", "*EcalEndcapN*"]))
+
+e_train = e[:len(pi)//2]
+pi_train = pi[:len(pi)//2]
+e_eval = e[len(pi)//2:]
+pi_eval = pi[len(pi)//2:]
+#+end_src
+
+#+RESULTS:
+:results:
+:end:
+
+#+begin_src jupyter-python
+nums = ak.num(pi_train["_EcalEndcapNParticleIDInput_features_floatData"], axis=1)
+num_features = ak.min(nums[nums > 0])
+print(f"{num_features} features")
+nums = ak.num(pi_train["_EcalEndcapNParticleIDTarget_int64Data"], axis=1)
+num_targets = ak.min(nums[nums > 0])
+print(f"{num_targets} targets")
+
+pi_train_leading_cluster_ixs = ak.argsort(pi_train["EcalEndcapNClusters.energy"], ascending=False)[:,:1]
+e_train_leading_cluster_ixs = ak.argsort(e_train["EcalEndcapNClusters.energy"], ascending=False)[:,:1]
+pi_train_x = ak.flatten(ak.unflatten(pi_train["_EcalEndcapNParticleIDInput_features_floatData"], num_features, axis=1)[pi_train_leading_cluster_ixs])
+pi_train_y = ak.flatten(ak.unflatten(pi_train["_EcalEndcapNParticleIDTarget_int64Data"], num_targets, axis=1)[pi_train_leading_cluster_ixs])
+e_train_x = ak.flatten(ak.unflatten(e_train["_EcalEndcapNParticleIDInput_features_floatData"], num_features, axis=1)[e_train_leading_cluster_ixs])
+e_train_y = ak.flatten(ak.unflatten(e_train["_EcalEndcapNParticleIDTarget_int64Data"], num_targets, axis=1)[e_train_leading_cluster_ixs])
+train_x = ak.concatenate([pi_train_x, e_train_x])
+train_y = ak.concatenate([pi_train_y, e_train_y])
+
+pi_eval_leading_cluster_ixs = ak.argsort(pi_eval["EcalEndcapNClusters.energy"], ascending=False)[:,:1]
+e_eval_leading_cluster_ixs = ak.argsort(e_eval["EcalEndcapNClusters.energy"], ascending=False)[:,:1]
+pi_eval_x = ak.flatten(ak.unflatten(pi_eval["_EcalEndcapNParticleIDInput_features_floatData"], num_features, axis=1)[pi_eval_leading_cluster_ixs])
+pi_eval_y = ak.flatten(ak.unflatten(pi_eval["_EcalEndcapNParticleIDTarget_int64Data"], num_targets, axis=1)[pi_eval_leading_cluster_ixs])
+e_eval_x = ak.flatten(ak.unflatten(e_eval["_EcalEndcapNParticleIDInput_features_floatData"], num_features, axis=1)[e_eval_leading_cluster_ixs])
+e_eval_y = ak.flatten(ak.unflatten(e_eval["_EcalEndcapNParticleIDTarget_int64Data"], num_targets, axis=1)[e_eval_leading_cluster_ixs])
+eval_x = ak.concatenate([pi_eval_x, e_eval_x])
+eval_y = ak.concatenate([pi_eval_y, e_eval_y])
+#+end_src
+
+#+RESULTS:
+:results:
+: 11 features
+: 2 targets
+:end:
+
+#+begin_src jupyter-python
+
+#+end_src
+
+#+RESULTS:
+:results:
+#+begin_export html
+<pre>[5.11,
+ 0.0424,
+ 3.03,
+ 2.16,
+ 17.7,
+ 8.32,
+ -4.54e-07,
+ 0.000456,
+ 0,
+ 69.2,
+ 0]
+------------------
+type: 11 * float32</pre>
+#+end_export
+:end:
+
+#+begin_src jupyter-python
+ak.sum((ak.num(pi_train_leading_cluster_ixs) != 0)), ak.num(pi_train_leading_cluster_ixs, axis=0)
+#+end_src
+
+#+RESULTS:
+:results:
+| 87721 | array | (88210) |
+:end:
+
+#+begin_src jupyter-python
+plt.hist(pi_eval_x[:,0])
+plt.hist(e_eval_x[:,0], alpha=0.5)
+plt.show()
+#+end_src
+
+#+RESULTS:
+:results:
+[[file:./.ob-jupyter/5381c9bd149f0bb8855bf539e7ce8ef927a2e1a9.png]]
+:end:
+
+#+begin_src jupyter-python
+"""
+fig, axs = plt.subplots(2, 4, sharex=True, sharey=True, figsize=(15, 6))
+fig_log, axs_log = plt.subplots(2, 4, sharex=True, sharey=True, figsize=(15, 6))
+fig_roc, axs_roc = plt.subplots(2, 4, sharex=True, sharey=True, figsize=(15, 6))
+
+axs = np.ravel(np.array(axs))
+axs_log = np.ravel(np.array(axs_log))
+axs_roc = np.ravel(np.array(axs_roc))
+
+rocs = {}
+
+for ix, energy in enumerate(energies):
+  for c in [False, True]:
+    energy_value = float(energy.replace("GeV", "").replace("MeV", "e-3"))
+    if c:
+      clf_label = "leading cluster"
+    else:
+      clf_label = "sum all hits"
+    def clf(events):
+      if c:       
+        return ak.drop_none(ak.max(events["EcalEndcapNClusters.energy"], axis=-1)) / energy_value
+      else:
+        return ak.sum(events["EcalEndcapNRecHits.energy"], axis=-1) / energy_value
+    e_pred = clf(e_eval[energy])
+    pi_pred = clf(pi_eval[energy])
+
+    for do_log in [False, True]:
+        plt.sca(axs[ix])
+        plt.hist(e_pred, bins=np.linspace(0., 1.01, 101), label=rf"$e^-$ {clf_label}")
+        plt.hist(pi_pred, bins=np.linspace(0., 1.01, 101), label=rf"$\pi^-$ {clf_label}", histtype="step")
+        plt.title(f"{energy}")
+        plt.legend()
+        if do_log: plt.yscale("log")
+
+    plt.sca(axs_roc[ix])
+    fpr, tpr, _ = roc_curve(
+        np.concatenate([np.ones_like(e_pred), np.zeros_like(pi_pred)]),
+        np.concatenate([e_pred, pi_pred]),
+    )
+    cond = fpr != 0 # avoid infinite rejection (region of large uncertainty)
+    cond &= tpr != 1 # avoid linear interpolation (region of large uncertainty)
+    def mk_interp(tpr, fpr):
+        def interp(eff):
+            return np.interp(eff, tpr, fpr)
+        return interp
+    rocs.setdefault(clf_label, {})[energy] = mk_interp(tpr, fpr)
+    plt.plot(tpr[cond] * 100, 1 / fpr[cond], label=f"{clf_label}")
+    plt.title(f"{energy}")
+    plt.legend()
+
+fig.show()
+plt.close(fig_log)
+fig_roc.show()
+
+plt.figure()
+for clf_label, roc in rocs.items():
+    plt.plot(
+        [float(energy.replace("GeV", "").replace("MeV", "e-3")) for energy in energies],
+        [1 / roc[energy](0.95) for energy in energies],
+        marker=".",
+        label=f"{clf_label}",
+    )
+plt.legend()
+plt.show()
+"""
+#+end_src
+
+#+begin_src jupyter-python
+import catboost
+clf = {}
+
+from sklearn.metrics import roc_curve
+roc = {}
+
+clf = catboost.CatBoostClassifier(loss_function="CrossEntropy", verbose=0, n_estimators=1000)
+clf.fit(
+    train_x.to_numpy(),
+    train_y.to_numpy()[:,1], # index 1 = is electron
+)
+plt.hist(clf.predict_proba(e_eval_x.to_numpy())[:,1], bins=np.linspace(0., 1.01, 101), label=r"$e^-$")
+plt.hist(clf.predict_proba(pi_eval_x.to_numpy())[:,1], bins=np.linspace(0., 1.01, 101), label=r"$\pi^-$", histtype="step")
+plt.xlabel("Classifier's probability prediction", loc="right")
+plt.ylabel("Number of events", loc="top")
+plt.legend(loc="upper center")
+plt.savefig(output_dir / "predict_proba.pdf", bbox_inches="tight")
+plt.show()
+#+end_src
+
+#+begin_src jupyter-python
+energy_bin_edges = np.arange(0., 20. + 1e-7, 1.)
+
+_eval_energy = eval_x[:,0]
+
+for energy_bin_ix, (energy_bin_low, energy_bin_high) in enumerate(zip(energy_bin_edges[:-1], energy_bin_edges[1:])):
+   cond = (_eval_energy >= energy_bin_low) & (_eval_energy < energy_bin_high)
+   print(energy_bin_low, energy_bin_high, ak.sum(cond))
+
+   pi_cond = (pi_eval_x[:,0] >= energy_bin_low) & (pi_eval_x[:,0] < energy_bin_high)
+   e_cond = (e_eval_x[:,0] >= energy_bin_low) & (e_eval_x[:,0] < energy_bin_high)
+   plt.hist(pi_eval_x[pi_cond][:,1], bins=np.linspace(0., 1.01, 101))
+   plt.hist(e_eval_x[e_cond][:,1], bins=np.linspace(0., 1.01, 101))
+   plt.yscale("log")
+   plt.show()
+   plt.clf()
+
+   fpr, tpr, _ = roc_curve(
+      eval_y[cond][:,1],
+      #eval_x[cond][:,1],
+      clf.predict_proba(eval_x[cond].to_numpy())[:,1],
+   )
+   cond = fpr != 0 # avoid infinite rejection (region of large uncertainty)
+   cond &= tpr != 1 # avoid linear interpolation (region of large uncertainty)
+   #cond &= tpr > 0.5
+   plt.plot(tpr[cond] * 100, 1 / fpr[cond])
+
+   def mk_interp(tpr, fpr):
+       def interp(eff):
+           return np.interp(eff, tpr, fpr)
+       return interp
+   roc[energy_bin_ix] = mk_interp(tpr, fpr)
+
+   plt.xlabel("Electron efficiency, %", loc="right")
+   plt.ylabel("Pion rejection factor", loc="top")
+   plt.title(rf"${energy_bin_low:.1f} < |\vec{{p}}| < {energy_bin_high:.1f}$ GeV")
+   plt.legend(loc="lower left")
+   plt.yscale("log")
+   plt.savefig(output_dir / f"roc_{energy_bin_low:.1f}_{energy_bin_high:.1f}.pdf", bbox_inches="tight")
+   plt.show()
+   plt.clf()
+
+plt.plot(
+    (energy_bin_edges[:-1] + energy_bin_edges[1:]) / 2,
+    [1 / roc[energy_bin_ix](0.95) for energy_bin_ix in range(len(energy_bin_edges) - 1)],
+    marker=".",
+    label="",
+)
+plt.yscale("log")
+plt.legend()
+plt.xlabel("Energy, GeV", loc="right")
+plt.ylabel("Pion rejection at 95%", loc="top")
+plt.savefig(output_dir / f"pion_rej.pdf", bbox_inches="tight")
+plt.show()
+#+end_src
+
+#+begin_src jupyter-python
+clf.save_model(
+    output_dir / "EcalEndcapN_pi_rejection.onnx",
+    format="onnx",
+    export_parameters={
+        "onnx_doc_string": "Classifier model for pion rejection in EEEMCal",
+        "onnx_graph_name": "CalorimeterParticleID_BinaryClassification",
+    }
+)
+import onnx
+
+model = onnx.load(output_dir / "EcalEndcapN_pi_rejection.onnx")
+onnx.checker.check_model(model)
+graph_def = onnx.helper.make_graph(
+    nodes=[model.graph.node[0]],
+    name=model.graph.name,
+    inputs=model.graph.input,
+    outputs=[model.graph.output[0], model.graph.value_info[0]],
+    initializer=model.graph.initializer
+)
+model_def = onnx.helper.make_model(graph_def, producer_name=model.producer_name)
+del model_def.opset_import[:]
+op_set = model_def.opset_import.add()
+op_set.domain = "ai.onnx.ml"
+op_set.version = 2
+model_def = onnx.shape_inference.infer_shapes(model_def)
+onnx.checker.check_model(model_def)
+onnx.save(model_def, output_dir / "EcalEndcapN_pi_rejection.onnx")
+#+end_src
+
+#+RESULTS:
+:results:
+:end:
+
+#+begin_src jupyter-python
+if "_EcalEndcapNParticleIDOutput_probability_tensor_floatData" in pi_train.fields:
+    nums = ak.num(pi_train["_EcalEndcapNParticleIDOutput_probability_tensor_floatData"], axis=1)
+    num_outputs = ak.min(nums[nums > 0])
+    print(f"{num_outputs} outputs")
+
+    pi_train_proba = ak.flatten(ak.unflatten(pi_train["_EcalEndcapNParticleIDOutput_probability_tensor_floatData"], num_outputs, axis=1)[pi_train_leading_cluster_ixs])
+    e_train_proba = ak.flatten(ak.unflatten(e_train["_EcalEndcapNParticleIDOutput_probability_tensor_floatData"], num_outputs, axis=1)[e_train_leading_cluster_ixs])
+    train_proba = ak.concatenate([pi_train_proba, e_train_proba])
+
+    pi_eval_proba = ak.flatten(ak.unflatten(pi_eval["_EcalEndcapNParticleIDOutput_probability_tensor_floatData"], num_outputs, axis=1)[pi_eval_leading_cluster_ixs])
+    e_eval_proba = ak.flatten(ak.unflatten(e_eval["_EcalEndcapNParticleIDOutput_probability_tensor_floatData"], num_outputs, axis=1)[e_eval_leading_cluster_ixs])
+    eval_proba = ak.concatenate([pi_eval_proba, e_eval_proba])
+
+    plt.hist(clf.predict_proba(eval_x.to_numpy())[:,1] - eval_proba[:,1].to_numpy())
+    plt.show()
+else:
+    print("EcalEndcapNParticleIDOutput not present")
+#+end_src
diff --git a/benchmarks/calo_pid/config.yml b/benchmarks/calo_pid/config.yml
new file mode 100644
index 00000000..8a65daa6
--- /dev/null
+++ b/benchmarks/calo_pid/config.yml
@@ -0,0 +1,45 @@
+sim:calo_pid:
+  extends: .det_benchmark
+  stage: simulate
+  parallel:
+    matrix:
+      - PARTICLE: ["e-", "pi-"]
+        INDEX_RANGE: [
+          "0 9",
+          "10 19",
+          "20 29",
+          "30 39",
+          "40 49",
+          "50 59",
+          "60 69",
+          "70 79",
+          "80 89",
+          "90 99",
+        ]
+  script:
+    - |
+      snakemake --cache --cores 5 \
+        $(seq --format="sim_output/calo_pid/epic_inner_detector/${PARTICLE}/100MeVto20GeV/130to177deg/${PARTICLE}_100MeVto20GeV_130to177deg.%04.f.eicrecon.tree.edm4eic.root" ${INDEX_RANGE})
+
+bench:calo_pid:
+  allow_failure: true # until inference merged into EICrecon
+  extends: .det_benchmark
+  stage: benchmarks
+  needs:
+    - ["sim:calo_pid"]
+  script:
+    - export PYTHONUSERBASE=$LOCAL_DATA_PATH/deps
+    - pip install -r benchmarks/calo_pid/requirements.txt
+    - snakemake --cache --cores 1 results/epic_inner_detector/calo_pid
+
+collect_results:calo_pid:
+  allow_failure: true # until inference merged into EICrecon
+  extends: .det_benchmark
+  stage: collect
+  needs:
+    - "bench:calo_pid"
+  script:
+    - ls -lrht
+    - mv results{,_save}/ # move results directory out of the way to preserve it
+    - snakemake --cores 1 --delete-all-output results/epic_inner_detector/calo_pid
+    - mv results{_save,}/
diff --git a/benchmarks/calo_pid/requirements.txt b/benchmarks/calo_pid/requirements.txt
new file mode 100644
index 00000000..dd564a91
--- /dev/null
+++ b/benchmarks/calo_pid/requirements.txt
@@ -0,0 +1,6 @@
+awkward >= 2.4.0
+catboost
+onnx
+scikit-learn
+uproot >= 5.2.0
+vector
-- 
GitLab