From a1b3797c807c765b292338e30fcc0da36d1ebba8 Mon Sep 17 00:00:00 2001
From: Chao Peng <cpeng@anl.gov>
Date: Sat, 17 Jul 2021 17:44:42 -0500
Subject: [PATCH] add a script to get layer ids

---
 .../imaging_ecal/scripts/get_layerids.py      | 72 +++++++++++++++++++
 1 file changed, 72 insertions(+)
 create mode 100644 benchmarks/imaging_ecal/scripts/get_layerids.py

diff --git a/benchmarks/imaging_ecal/scripts/get_layerids.py b/benchmarks/imaging_ecal/scripts/get_layerids.py
new file mode 100644
index 00000000..72457289
--- /dev/null
+++ b/benchmarks/imaging_ecal/scripts/get_layerids.py
@@ -0,0 +1,72 @@
+'''
+    A simple analysis script that decodes the layer id of every hit in a collection from its cellID and prints it
+'''
+import os
+import DDG4
+import ROOT
+import pandas as pd
+import numpy as np
+import argparse
+from matplotlib import pyplot as plt
+import matplotlib.ticker as ticker
+
+
+def flatten_collection(rdf, collection, cols=None):
+    """Flatten the per-event vectors of `collection` in `rdf` into a pandas DataFrame.
+
+    `cols` selects members (default: every '<collection>.*' column); 'event' holds the entry index.
+    """
+    prefix = '{}.'.format(collection)
+    if cols:
+        branches = [prefix + str(c) for c in cols]
+    else:
+        branches = [str(c) for c in rdf.GetColumnNames() if str(c).startswith(prefix)]
+    if not branches:
+        print('cannot find any branch under collection {}'.format(collection))
+        return pd.DataFrame()
+
+    data = rdf.AsNumpy(branches)
+    evns = [i for i, vec in enumerate(data[branches[0]]) for _ in range(vec.size())]
+    flat = {}
+    for name, vals in data.items():
+        # keep ints integral; vals[0] does not exist when the tree has no entries, fall back to float
+        typename = vals[0].__class__.__name__.lower() if len(vals) else ''
+        dtype = np.int64 if 'int' in typename or 'long' in typename else np.float64
+        flat[name] = pd.Series(np.asarray([v for vec in vals for v in vec], dtype=dtype))
+    dfp = pd.DataFrame(flat)
+    dfp.loc[:, 'event'] = evns
+    return dfp
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('rec_file', help='Path to reconstruction output file.')
+    parser.add_argument('-o', dest='outdir', default='.', help='Output directory.')
+    parser.add_argument('-c', '--compact', dest='compact', required=True,
+            help='Top-level xml file of the detector description')
+    parser.add_argument('--collection', dest='coll', required=True,
+            help='Hits collection name in the reconstruction file')
+    parser.add_argument('--readout', dest='readout', required=True,
+            help='Readout name for the hits collection')
+    args = parser.parse_args()
+
+    # get hits; strip the '<collection>.' prefix so columns are plain member names
+    rdf_rec = ROOT.RDataFrame('events', args.rec_file)
+    df = flatten_collection(rdf_rec, args.coll)
+    df.rename(columns={c: c.replace(args.coll + '.', '') for c in df.columns}, inplace=True)
+
+    # initialize dd4hep detector
+    kernel = DDG4.Kernel()
+    try:
+        description = kernel.detectorDescription()
+        kernel.loadGeometry("file:{}".format(args.compact))
+        # decode the 'layer' field of every cellID; otypes keeps vectorize usable on empty input
+        decoder = description.readout(args.readout).idSpec().decoder()
+        lindex = decoder.index('layer')
+        get_layer_id = np.vectorize(lambda cid: decoder.get(cid, lindex), otypes=[np.int64])
+        df.loc[:, 'layerID'] = get_layer_id(df['cellID'].astype(int).values)
+        print(df[['cellID', 'layerID', 'position.x', 'position.y', 'position.z', 'energy']])
+    finally:
+        # always terminate dd4hep kernel, even when decoding or printing raised
+        kernel.terminate()
+
-- 
GitLab