diff --git a/benchmarks/imaging_shower_ML/scripts/check_edep_dists.py b/benchmarks/imaging_shower_ML/scripts/check_edep_dists.py
new file mode 100644
index 0000000000000000000000000000000000000000..404855b1c6fac23883ad7b3537d67b4013edc123
--- /dev/null
+++ b/benchmarks/imaging_shower_ML/scripts/check_edep_dists.py
@@ -0,0 +1,135 @@
+import ROOT
+import os
+import gc
+import numpy as np
+import pandas as pd
+import argparse
+import sys
+import matplotlib.pyplot as plt
+
+
+# read from RDataFrame and flatten a given collection, return pandas dataframe
+def flatten_collection(rdf, collection, cols=None, event_colname='event'):
+    if not cols:
+        cols = [str(c) for c in rdf.GetColumnNames() if str(c).startswith('{}.'.format(collection))]
+    else:
+        cols = ['{}.{}'.format(collection, c) for c in cols]
+    if not cols:
+        print('cannot find any branch under collection {}'.format(collection))
+        return pd.DataFrame()
+
+    data = rdf.AsNumpy(cols)
+    # flatten the data, add an event id to identify clusters from different events
+    evns = []
+    for i, vec in enumerate(data[cols[0]]):
+        evns += [i]*vec.size()
+    for n, vals in data.items():
+        # make sure ints are not converted to floats
+        typename = vals[0].__class__.__name__.lower()
+        dtype = np.int64 if 'int' in typename or 'long' in typename else np.float32
+        # type safe creation
+        data[n] = np.asarray([v for vec in vals for v in vec], dtype=dtype)
+    # build data frame
+    dfp = pd.DataFrame({c: pd.Series(v) for c, v in data.items()})
+    dfp.loc[:, event_colname] = evns
+    return dfp
+
+
+# load root macros, input is an argument string
+def load_root_macros(arg_macros):
+    for path in arg_macros.split(','):
+        path = path.strip()
+        if os.path.exists(path):
+            ROOT.gROOT.Macro(path)
+        else:
+            print('\"{}\" does not exist, skip loading it.'.format(path))
+
+
+def cartesian_to_polar(x, y, z):
+    r = np.sqrt(x**2 + y**2 + z**2)
+    rc = np.sqrt(x**2 + y**2)
+    theta = np.arccos(z / r)
+    phi = np.arctan2(y, x)
+    eta = -np.log(np.tan(theta / 2.))
+    return r, theta, phi, rc, eta
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('file', type=str, help='path to root file')
+    parser.add_argument('--sim', action='store_true', help='flag switch between sim and rec data (default rec)')
+    parser.add_argument('-o', type=str, default='.', dest='outdir', help='output directory')
+    # parser.add_argument('--compact', type=str, default='', dest='compact', help='compact file')
+    parser.add_argument('-m', '--macros', type=str, default='rootlogon.C', dest='macros',
+                         help='root macros to load (accept multiple paths separated by \",\")')
+    parser.add_argument('--branch', type=str, default='EcalBarrelImagingHitsReco', help='name of data branch (eic::CalorimeterHitCollection)')
+    parser.add_argument('--truth-branch', type=str, default='mcparticles', help='name of truth mc branch')
+    parser.add_argument('--edep-max', type=float, default=0., help='maximum edep (GeV) to plot')
+    parser.add_argument('--edep-nbins', type=int, default=200, help='number of bins')
+    parser.add_argument('--name-tag', type=str, default='test', help='name tag to save the file')
+    parser.add_argument('--samp-frac', type=float, default=1.0, help='sampling fraction')
+    args = parser.parse_args()
+
+    os.makedirs(args.outdir, exist_ok=True)
+    load_root_macros(args.macros)
+
+    # read data and mcparticles
+    rdf = ROOT.RDataFrame("events", args.file)
+
+    mc_branch = args.truth_branch
+    dfm = flatten_collection(rdf, mc_branch, ['genStatus', 'pdgID', 'ps.x', 'ps.y', 'ps.z', 'mass'])
+    dfm.rename(columns={c: c.replace(mc_branch + '.', '') for c in dfm.columns}, inplace=True)
+    # selete incident particles
+    dfm = dfm[dfm['genStatus'].isin([0, 1])]
+    # NOTE: assumed single particles
+    dfm = dfm.groupby('event').first()
+    # p, theta, phi, pT, eta = cartesian_to_polar(*dfm[['ps.x', 'ps.y', 'ps.z']].values.T)
+
+    if args.sim:
+        df = flatten_collection(rdf, args.branch, ['energyDeposit'])
+        df.rename(columns={c: c.replace(args.branch + '.', '') for c in df.columns}, inplace=True)
+        df.rename(columns={'energyDeposit': 'energy'}, inplace=True)
+    else:
+        df = flatten_collection(rdf, args.branch, ['layer', 'energy', 'position.x', 'position.y', 'position.z'])
+        df.rename(columns={c: c.replace(args.branch + '.', '') for c in df.columns}, inplace=True)
+
+    dfe = df.groupby('event')['energy'].sum().reset_index()
+    # determine histrogram bins
+    if args.edep_max <= 0.:
+        args.edep_max = dfe['energy'].quantile(0.99)*1.2
+    bins = np.linspace(0., args.edep_max, args.edep_nbins + 1)
+    bincenters = (bins[1:] + bins[:-1])/2.
+
+    # get particle types
+    fig, ax = plt.subplots(figsize=(16, 9), dpi=120, gridspec_kw={'left': 0.15, 'right': 0.95})
+    ax.set_xlabel('Energy Deposit / {:.2f} (GeV)'.format(args.samp_frac), fontsize=24)
+    ax.set_ylabel('Normalized Counts', fontsize=24)
+    ax.set_yscale('log')
+    ax.grid(linestyle=':')
+    ax.tick_params(labelsize=24)
+    ax.set_axisbelow(True)
+
+    hist_vals, hist_cols = [], []
+    pdgbase = ROOT.TDatabasePDG()
+    for pdgid in dfm['pdgID'].unique():
+        particle = pdgbase.GetParticle(int(pdgid))
+        if not particle:
+            print("Unknown pdgcode {}, they are ignored".format(int(pdgid)))
+            continue
+        events_indices = dfm[dfm.loc[:, 'pdgID'] == pdgid].index.unique()
+        print("{} entries of particle {}".format(len(events_indices), particle.GetName()))
+
+        dfe_part = dfe.loc[dfe['event'].isin(events_indices)]
+
+        edep_vals, _, _ = ax.hist(dfe_part['energy'] / args.samp_frac,
+                                  weights = [1. / dfe_part.shape[0]]*dfe_part.shape[0],
+                                  histtype='step', bins=bins, label=particle.GetName())
+        hist_vals.append(edep_vals)
+        hist_cols.append(particle.GetName())
+
+    pd.DataFrame(index=bincenters, data=np.vstack(hist_vals).T, columns=hist_cols)\
+      .to_csv('{}.csv'.format(args.name_tag))
+    ax.legend(fontsize=24)
+    fig.savefig('{}.png'.format(args.name_tag))
+
+