From 43c01cc4b5dfb3014dc8c03d370af984d28c6a5b Mon Sep 17 00:00:00 2001
From: Chao Peng <cpeng@anl.gov>
Date: Thu, 3 Nov 2022 21:45:03 +0000
Subject: [PATCH] Analysis script for checking merged hits for imaging barrel
 ecal

---
 .../imaging_ecal/options/hybrid_cluster.py    |   4 +-
 .../imaging_ecal/scripts/check_merged_hits.py | 157 ++++++++++++++++++
 .../imaging_ecal/scripts/get_layerids.py      | 100 +----------
 benchmarks/imaging_ecal/scripts/utils.py      |  98 ++++++++++-
 4 files changed, 265 insertions(+), 94 deletions(-)
 create mode 100644 benchmarks/imaging_ecal/scripts/check_merged_hits.py

diff --git a/benchmarks/imaging_ecal/options/hybrid_cluster.py b/benchmarks/imaging_ecal/options/hybrid_cluster.py
index 71e46e0a..5c50009c 100644
--- a/benchmarks/imaging_ecal/options/hybrid_cluster.py
+++ b/benchmarks/imaging_ecal/options/hybrid_cluster.py
@@ -160,8 +160,8 @@ if has_ecal_barrel_scfi:
         #OutputLevel=DEBUG,
         inputHitCollection=scfi_barrel_reco.outputHitCollection,
         outputHitCollection="EcalBarrelScFiGridReco",
-        fields=["fiber"],
-        fieldRefNumbers=[1],
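+        # in addition to "fiber", also merge along "z" so that fiber signals are
+        # grouped into z-grids (the reference numbers are presumably the values
+        # the merged fields are reset to)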
+        fields=["fiber", "z"],
+        fieldRefNumbers=[1, 1],
         readoutClass="EcalBarrelScFiHits")
     algorithms.append(scfi_barrel_merger)
 
diff --git a/benchmarks/imaging_ecal/scripts/check_merged_hits.py b/benchmarks/imaging_ecal/scripts/check_merged_hits.py
new file mode 100644
index 00000000..f381ac45
--- /dev/null
+++ b/benchmarks/imaging_ecal/scripts/check_merged_hits.py
@@ -0,0 +1,157 @@
+'''
+    A simple analysis script to compare raw hits against merged hits
+    *  The ReadoutDecoder class depends only on the Python XML parser and is thus much faster, but it may
+       fail if the compact files use features that the XML parser does not support (but the dd4hep parser does)
+    ** The dd4hep decoder can also be used (more reliable but much slower); see get_layerids.py for the code
+    Chao Peng (ANL)
+'''
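+# example invocation (the paths and the raw collection name are hypothetical; the
+# other names are taken from options/hybrid_cluster.py):
+#   python check_merged_hits.py rec.root -c athena.xml --readout EcalBarrelScFiHits \
+#       --raw-coll EcalBarrelScFiRecoHits --merged-coll EcalBarrelScFiGridReco --merged-fields fiber,z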
+import os
+import ROOT
+import pandas as pd
+import numpy as np
+import argparse
+from matplotlib import pyplot as plt
+import matplotlib.ticker as ticker
+from utils import ReadoutDecoder, flatten_collection
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('rec_file', help='Path to reconstruction output file.')
+    parser.add_argument('-o', dest='outdir', default='.', help='Output directory.')
+    parser.add_argument('-c', '--compact', dest='compact', required=True,
+            help='Top-level xml file of the detector description')
+    parser.add_argument('--raw-coll', dest='rcoll', required=True,
+            help='Raw hits collection name in the reconstruction file')
+    parser.add_argument('--merged-coll', dest='mcoll', required=True,
+            help='Merged hits collection name in the reconstruction file')
+    parser.add_argument('--readout', dest='readout', required=True,
+            help='Readout name for the hits collection')
+    parser.add_argument('--merged-fields', dest='mfields', required=True,
+            help='Merged fields (separated by \",\")')
+    args = parser.parse_args()
+    # make sure the output directory exists (it was parsed but never used before)
+    os.makedirs(args.outdir, exist_ok=True)
+
+    # decoder and merge mask
+    decoder = ReadoutDecoder(args.compact, args.readout)
+    mmask = np.uint64(0)
+    for field in args.mfields.split(','):
+        fmask = decoder.mask(field.strip())
+        mmask = np.bitwise_or(mmask, fmask)
+        print('{:20s} {:#066b}'.format('Field Mask - ' + field, fmask))
+    mmask = np.bitwise_not(mmask)
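+    # after inversion, the mask keeps every bit of a cellID except those of the merged fields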
+    print('{:20s} {:#066b}'.format('Merging Mask', mmask))
+
+    # build dataframes from reconstructed root files
+    # get hits
+    columns = ['cellID', 'position.x', 'position.y', 'position.z', 'energy']
+    rdf_rec = ROOT.RDataFrame('events', args.rec_file)
+    # raw hits
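+    # (only the first 100 events are read to keep the fiber-level dataframe small)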
+    dfr = flatten_collection(rdf_rec.Range(100), args.rcoll, columns)
+    dfr.rename(columns={c: c.replace(args.rcoll + '.', '') for c in dfr.columns}, inplace=True)
+    # merged hits
+    dfm = flatten_collection(rdf_rec, args.mcoll, columns)
+    dfm.rename(columns={c: c.replace(args.mcoll + '.', '') for c in dfm.columns}, inplace=True)
+
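+    # applying the mask zeroes the merged fields, so raw hits that belong to the same
+    # merged cell share one masked_id, directly comparable to a merged hit's merged_id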
+    dfr.loc[:, 'masked_id'] = np.bitwise_and(mmask, dfr['cellID'])
+    dfm.loc[:, 'merged_id'] = np.bitwise_and(mmask, dfm['cellID'])
+
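+    # for each raw hit, count the merged hits in the same event that share its masked ID
+    # (expected to be exactly 1 if every raw hit ended up in a merged cell)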
+    dfr = dfr.merge(dfm.groupby('event')['merged_id'].apply(np.array), on='event')
+    dfr.loc[:, 'raw_merge_num'] = dfr.apply(lambda r: np.count_nonzero(r.masked_id == r.merged_id), axis=1)
+
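+    # conversely, for each merged hit, count the raw hits that were merged into it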
+    dfm = dfm.merge(dfr.groupby('event')['masked_id'].apply(np.array), on='event')
+    dfm.loc[:, 'merge_raw_num'] = dfm.apply(lambda r: np.count_nonzero(r.merged_id == r.masked_id), axis=1)
+
+    print(dfr[['masked_id', 'merged_id', 'raw_merge_num']])
+    print(dfm[['masked_id', 'merged_id', 'merge_raw_num']])
+
+    # dfr.to_csv('raw_hits.csv')
+    # dfm.to_csv('merged_hits.csv')
+
+    # merged hits (grid)
+    # dfm = pd.read_csv('merged_hits.csv', index_col=0)
+    ev_grpm = dfm.groupby('event')
+    # energy scaling: normalize so the mean total deposit per event corresponds to 5 GeV (5000 MeV)
+    escale = 5000./np.mean(ev_grpm['energy'].sum())
+    dfm['energy'] = dfm['energy']*escale
+    ev_ngrids = ev_grpm['energy'].size()
+    nev = len(ev_ngrids)
+
+    # number of grids per event
+    fig, ax = plt.subplots(figsize=(12, 9), dpi=160)
+    ax.hist(ev_ngrids, ec='k', bins=30, rwidth=1.0, weights=np.repeat(1./nev, nev))
+    ax.tick_params(labelsize=24)
+    ax.set_title('5.0 GeV $e^-$', fontsize=26)
+    ax.set_xlabel('# of Grids per Event', fontsize=24)
+    ax.set_ylabel('Normalized Counts', fontsize=24)
+    fig.savefig(os.path.join(args.outdir, 'grid_per_event.png'))
+
+    # energy deposit per grid
+    fig, ax = plt.subplots(figsize=(12, 9), dpi=160)
+    ax.hist(dfm['energy'], ec='k', bins=30, rwidth=1.0, weights=np.repeat(1./len(dfm), len(dfm)))
+    ax.tick_params(labelsize=24)
+    ax.set_title('5.0 GeV $e^-$', fontsize=26)
+    ax.set_xlabel('Grid $E_{dep}$ (MeV)', fontsize=24)
+    ax.set_ylabel('Normalized Counts', fontsize=24)
+    fig.savefig(os.path.join(args.outdir, 'edep_per_grid.png'))
+
+    # number of fired fibers per grid
+    fig, ax = plt.subplots(figsize=(12, 9), dpi=160)
+    ax.hist(dfm['merge_raw_num'], ec='k', bins=30, rwidth=1.0, weights=np.repeat(1./len(dfm), len(dfm)))
+    ax.tick_params(labelsize=24)
+    ax.set_title('5.0 GeV $e^-$', fontsize=26)
+    ax.set_xlabel('# of Fiber-Signals per Grid', fontsize=24)
+    ax.set_ylabel('Normalized Counts', fontsize=24)
+    fig.savefig(os.path.join(args.outdir, 'signals_per_grid.png'))
+
+
+    # raw hits (fiber)
+    # dfr = pd.read_csv('raw_hits.csv', index_col=0)
+    dfr['energy'] = dfr['energy']*escale
+    ev_grpr = dfr.groupby('event')
+    ev_nfibers = ev_grpr['energy'].size()
+    nev = len(ev_nfibers)
+
+    # number of fibers per event
+    fig, ax = plt.subplots(figsize=(12, 9), dpi=160)
+    ax.hist(ev_nfibers, ec='k', bins=30, rwidth=1.0, weights=np.repeat(1./nev, nev))
+    ax.tick_params(labelsize=24)
+    ax.set_title('5.0 GeV $e^-$', fontsize=26)
+    ax.set_xlabel('# of Fibers per Event', fontsize=24)
+    ax.set_ylabel('Normalized Counts', fontsize=24)
+    fig.savefig(os.path.join(args.outdir, 'fiber_per_event.png'))
+
+    # energy deposit per fiber
+    fig, ax = plt.subplots(figsize=(12, 9), dpi=160)
+    ax.hist(dfr['energy'], ec='k', bins=30, rwidth=1.0, weights=np.repeat(1./len(dfr), len(dfr)))
+    ax.tick_params(labelsize=24)
+    ax.set_title('5.0 GeV $e^-$', fontsize=26)
+    ax.set_xlabel('Fiber $E_{dep}$ (MeV)', fontsize=24)
+    ax.set_ylabel('Normalized Counts', fontsize=24)
+    fig.savefig(os.path.join(args.outdir, 'edep_per_fiber.png'))
+
+    # number of grids per fiber signal
+    fig, ax = plt.subplots(figsize=(12, 9), dpi=160)
+    ax.hist(dfr['raw_merge_num'], ec='k', bins=np.arange(-0.5, 7.5, 1.0),
+            rwidth=1.0, weights=np.repeat(1./len(dfr), len(dfr)))
+    ax.tick_params(labelsize=24)
+    ax.set_title('5.0 GeV $e^-$', fontsize=26)
+    ax.set_xlabel('# of Grids per Fiber-Signal', fontsize=24)
+    ax.set_ylabel('Normalized Counts', fontsize=24)
+    fig.savefig(os.path.join(args.outdir, 'grids_per_fiber.png'))
+
+    # build relations between two datasets
+    dfe = dfr.groupby(['event', 'masked_id'])['energy'].sum().reset_index()
+    print(dfe)
+    dfe.rename(columns={'energy':'energy_fibers'}, inplace=True)
+    dfem = dfm[['event', 'merged_id', 'energy']].rename(columns={'merged_id': 'masked_id'})
+    dfe = dfe.merge(dfem, on=['event', 'masked_id'])
+
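+    # the summed fiber energies within a grid should match the merged grid energy, so
+    # the scatter below is expected to lie along the diagonal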
+    fig, ax = plt.subplots(figsize=(12, 9), dpi=160)
+    ax.scatter(*dfe[['energy_fibers', 'energy']].values.T)
+    ax.tick_params(labelsize=24)
+    ax.set_title('5.0 GeV $e^-$ (100 events)', fontsize=26)
+    ax.set_xlabel('Fiber Group Energy Sum (MeV)', fontsize=24)
+    ax.set_ylabel('Grid Energy (MeV)', fontsize=24)
+    fig.savefig(os.path.join(args.outdir, 'fiber_grid_energy.png'))
+
+
diff --git a/benchmarks/imaging_ecal/scripts/get_layerids.py b/benchmarks/imaging_ecal/scripts/get_layerids.py
index b62871df..a9907b75 100644
--- a/benchmarks/imaging_ecal/scripts/get_layerids.py
+++ b/benchmarks/imaging_ecal/scripts/get_layerids.py
@@ -1,5 +1,9 @@
 '''
-    A simple analysis script to extract some basic info of Monte-Carlo hits
+    An example script showing how to read data from a reconstructed ROOT file and decode its cellID column
+    *  The commented code uses the dd4hep decoder, which takes a significant amount of time to initialize
+    ** The ReadoutDecoder class depends only on the Python XML parser and is thus much faster, but it may
+       fail if the compact files use features that the XML parser does not support (but the dd4hep parser does)
+    Chao Peng (ANL)
 '''
 import os
 import ROOT
@@ -8,93 +12,7 @@ import numpy as np
 import argparse
 from matplotlib import pyplot as plt
 import matplotlib.ticker as ticker
-from lxml import etree as ET
-
-
-class AthenaDecoder:
-    def __init__(self, compact, readout):
-        self.readouts = self.getReadouts(compact)
-        self.changeReadout(readout)
-
-    def changeReadout(self, readout):
-        self.fieldsmap = self.decomposeIDs(self.readouts[readout])
-
-    def get(self, idvals, field):
-        start, width = self.fieldsmap[field]
-        if width >= 0:
-            return np.bitwise_and(np.right_shift(idvals, start), (1 << width) - 1)
-        # first bit is sign bit
-        else:
-            width = abs(width) - 1
-            vals = np.bitwise_and(np.right_shift(idvals, start), (1 << width) - 1)
-            return np.where(np.bitwise_and(np.right_shift(idvals, start + width), 1), vals - (1 << width), vals)
-
-    def decode(self, idvals):
-        return {field: self.get(idvals, field) for field, _ in self.fieldsmap.items()}
-
-    @staticmethod
-    def getReadouts(path):
-        res = dict()
-        AthenaDecoder.__getReadoutsRecur(path, res)
-        return res
-
-    @staticmethod
-    def __getReadoutsRecur(path, res):
-        if not os.path.exists(path):
-            print('Xml file {} not exist! Ignored it.'.format(path))
-            return
-        lccdd = ET.parse(path).getroot()
-        readouts = lccdd.find('readouts')
-        if readouts is not None:
-            for readout in readouts.getchildren():
-                ids = readout.find('id')
-                if ids is not None:
-                    res[readout.attrib['name']] = ids.text
-        for child in lccdd.getchildren():
-            if child.tag == 'include':
-                root_dir = os.path.dirname(os.path.realpath(path))
-                AthenaDecoder.__getReadoutsRecur(os.path.join(root_dir, child.attrib['ref']), res)
-
-    @staticmethod
-    def decomposeIDs(id_str):
-        res = dict()
-        curr_bit = 0
-        for field_bits in id_str.split(','):
-            elements = field_bits.split(':')
-            field_name = elements[0]
-            bit_width = int(elements[-1])
-            if len(elements) == 3:
-                curr_bit = int(elements[1])
-            res[field_name] = (curr_bit, bit_width)
-            curr_bit += abs(bit_width)
-        return res
-
-
-# read from RDataFrame and flatten a given collection, return pandas dataframe
-def flatten_collection(rdf, collection, cols=None):
-    if not cols:
-        cols = [str(c) for c in rdf.GetColumnNames() if str(c).startswith('{}.'.format(collection))]
-    else:
-        cols = ['{}.{}'.format(collection, c) for c in cols]
-    if not cols:
-        print('cannot find any branch under collection {}'.format(collection))
-        return pd.DataFrame()
-
-    data = rdf.AsNumpy(cols)
-    # flatten the data, add an event id to identify clusters from different events
-    evns = []
-    for i, vec in enumerate(data[cols[0]]):
-        evns += [i]*vec.size()
-    for n, vals in data.items():
-        # make sure ints are not converted to floats
-        typename = vals[0].__class__.__name__.lower()
-        dtype = np.int64 if 'int' in typename or 'long' in typename else np.float64
-        # type safe creation
-        data[n] = np.asarray([v for vec in vals for v in vec], dtype=dtype)
-    # build data frame
-    dfp = pd.DataFrame({c: pd.Series(v) for c, v in data.items()})
-    dfp.loc[:, 'event'] = evns
-    return dfp
+from utils import ReadoutDecoder, flatten_collection
 
 
 if __name__ == '__main__':
@@ -110,7 +28,7 @@ if __name__ == '__main__':
     args = parser.parse_args()
 
     # decoder
-    decoder = AthenaDecoder(args.compact, args.readout)
+    decoder = ReadoutDecoder(args.compact, args.readout)
 
     # get hits
     rdf_rec = ROOT.RDataFrame('events', args.rec_file)
@@ -131,6 +49,6 @@ if __name__ == '__main__':
 
     # faster way to get layerids
     df.loc[:, 'layerID'] = decoder.get(df['cellID'].values, 'layer')
-    df.loc[:, 'xID'] = decoder.get(df['cellID'].values, 'x')
-    print(df[['cellID', 'layerID', 'xID', 'position.x', 'position.y', 'position.z', 'energy']])
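+    # the merger configuration in options/hybrid_cluster.py now uses the readout's 'z'
+    # field, so decode 'z' here instead of 'x'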
+    df.loc[:, 'zID'] = decoder.get(df['cellID'].values, 'z')
+    print(df[['cellID', 'layerID', 'zID', 'position.x', 'position.y', 'position.z', 'energy']])
 
diff --git a/benchmarks/imaging_ecal/scripts/utils.py b/benchmarks/imaging_ecal/scripts/utils.py
index efdc8424..2ac8d1ef 100644
--- a/benchmarks/imaging_ecal/scripts/utils.py
+++ b/benchmarks/imaging_ecal/scripts/utils.py
@@ -16,8 +16,104 @@ import pandas as pd
 import matplotlib
 import DDG4
 from ROOT import gROOT, gInterpreter
+from lxml import etree as ET
 
 
+class ReadoutDecoder:
+    def __init__(self, compact, readout):
+        self.readouts = self.getReadouts(compact)
+        self.changeReadout(readout)
+
+    def changeReadout(self, readout):
+        self.fieldsmap = self.decomposeIDs(self.readouts[readout])
+
+    def get(self, idvals, field):
+        start, width = self.fieldsmap[field]
+        if width >= 0:
+            return np.bitwise_and(np.right_shift(idvals, start), (1 << width) - 1)
+        # the top bit of the field is the sign bit; decode the value as two's complement
+        else:
+            width = abs(width) - 1
+            vals = np.bitwise_and(np.right_shift(idvals, start), (1 << width) - 1)
+            return np.where(np.bitwise_and(np.right_shift(idvals, start + width), 1), vals - (1 << width), vals)
+
+    def mask(self, field):
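+        # bit mask covering the field: abs(width) ones shifted up to the field's bit offset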
+        start, width = self.fieldsmap[field]
+        return np.uint64((2**abs(width) - 1) << start)
+
+    def decode(self, idvals):
+        return {field: self.get(idvals, field) for field, _ in self.fieldsmap.items()}
+
+    @staticmethod
+    def getReadouts(path):
+        res = dict()
+        ReadoutDecoder.__getReadoutsRecur(path, res)
+        return res
+
+    @staticmethod
+    def __getReadoutsRecur(path, res):
+        if not os.path.exists(path):
+            print('XML file {} does not exist! Ignoring it.'.format(path))
+            return
+        lccdd = ET.parse(path).getroot()
+        readouts = lccdd.find('readouts')
+        if readouts is not None:
+            for readout in readouts.getchildren():
+                ids = readout.find('id')
+                if ids is not None:
+                    res[readout.attrib['name']] = ids.text
+        for child in lccdd.getchildren():
+            if child.tag == 'include':
+                root_dir = os.path.dirname(os.path.realpath(path))
+                ReadoutDecoder.__getReadoutsRecur(os.path.join(root_dir, child.attrib['ref']), res)
+
+    @staticmethod
+    def decomposeIDs(id_str):
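+        # id_str follows the dd4hep ID descriptor format: comma-separated fields, each
+        # "name:width" or "name:start:width"; a negative width marks a signed field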
+        res = dict()
+        curr_bit = 0
+        for field_bits in id_str.split(','):
+            elements = field_bits.split(':')
+            field_name = elements[0]
+            bit_width = int(elements[-1])
+            if len(elements) == 3:
+                curr_bit = int(elements[1])
+            res[field_name] = (curr_bit, bit_width)
+            curr_bit += abs(bit_width)
+        return res
+
+
+# read from RDataFrame and flatten a given collection, return pandas dataframe
+def flatten_collection(rdf, collection, cols=None):
+    if not cols:
+        cols = [str(c) for c in rdf.GetColumnNames() if str(c).startswith('{}.'.format(collection))]
+    else:
+        cols = ['{}.{}'.format(collection, c) for c in cols]
+    if not cols:
+        print('Cannot find any branch under collection {}'.format(collection))
+        return pd.DataFrame()
+    # print(rdf.GetColumnNames())
+    data = rdf.AsNumpy(cols)
+    # flatten the data and add an event id to keep entries from different events separate
+    evns = []
+    for i, vec in enumerate(data[cols[0]]):
+        evns += [i]*vec.size()
+    for n, vals in data.items():
+        # make sure ints are not converted to floats
+        typename = vals[0].__class__.__name__.lower()
+        # default
+        dtype = np.float64
+        if 'unsigned int' in typename or 'unsigned long' in typename:
+            dtype = np.uint64
+        elif 'int' in typename or 'long' in typename:
+            dtype = np.int64
+        # print(n, typename, dtype)
+        # type safe creation
+        data[n] = np.asarray([v for vec in vals for v in vec], dtype=dtype)
+    # build data frame
+    dfp = pd.DataFrame({c: pd.Series(v) for c, v in data.items()})
+    dfp.loc[:, 'event'] = evns
+    return dfp
+
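+# example usage of flatten_collection (hypothetical file name):
+#   rdf = ROOT.RDataFrame('events', 'rec.root')
+#   df = flatten_collection(rdf, 'EcalBarrelScFiHits', ['cellID', 'energy'])
+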
 # helper function to truncate color map (for a better view from the rainbow colormap)
 def truncate_colormap(cmap, minval=0.0, maxval=1.0, n=100):
     new_cmap = matplotlib.colors.LinearSegmentedColormap.from_list(
@@ -106,7 +202,7 @@ def get_all_mcp(path, evnums=None, branch='MCParticles'):
         for ptl in getattr(events, branch):
             dbuf[idb] = (iev, ptl.momentum.x, ptl.momentum.y, ptl.momentum.z, ptl.PDG, ptl.simulatorStatus, ptl.endpoint.x, ptl.endpoint.y, ptl.endpoint.z)
             idb += 1
-    
+
     return pd.DataFrame(data=dbuf[:idb], columns=['event', 'px', 'py', 'pz', 'pid', 'status', 'vex', 'vey', 'vez'])
 
 # read hits data from root file
-- 
GitLab