From 62751498cecea54e8238c67234cc5c6d795e661f Mon Sep 17 00:00:00 2001
From: Chao Peng <cpeng@anl.gov>
Date: Sat, 17 Jul 2021 17:45:37 -0500
Subject: [PATCH] flatten_collection: keep integer dtypes, handle empty collections

---
 benchmarks/clustering/scripts/cluster_plots.py | 13 +++++++++++--
 benchmarks/ecal/scripts/draw_cluters.py        | 12 ++++++++++--
 2 files changed, 21 insertions(+), 4 deletions(-)

diff --git a/benchmarks/clustering/scripts/cluster_plots.py b/benchmarks/clustering/scripts/cluster_plots.py
index 8d4ec745..e2a53cfd 100644
--- a/benchmarks/clustering/scripts/cluster_plots.py
+++ b/benchmarks/clustering/scripts/cluster_plots.py
@@ -27,15 +27,23 @@ def flatten_collection(rdf, collection, cols=None):
         cols = [str(c) for c in rdf.GetColumnNames() if str(c).startswith('{}.'.format(collection))]
     else:
         cols = ['{}.{}'.format(collection, c) for c in cols]
+    if not cols:
+        print('cannot find any branch under collection {}'.format(collection))
+        return pd.DataFrame()
+
     data = rdf.AsNumpy(cols)
     # flatten the data, add an event id to identify clusters from different events
     evns = []
     for i, vec in enumerate(data[cols[0]]):
         evns += [i]*vec.size()
     for n, vals in data.items():
-        data[n] = np.asarray([v for vec in vals for v in vec])
+        # make sure ints are not converted to floats
+        typename = vals[0].__class__.__name__.lower()
+        dtype = np.int64 if 'int' in typename or 'long' in typename else np.float64
+        # type safe creation
+        data[n] = np.asarray([v for vec in vals for v in vec], dtype=dtype)
     # build data frame
-    dfp = pd.DataFrame(columns=cols, data=np.vstack(list(data.values())).T)
+    dfp = pd.DataFrame({c: pd.Series(v) for c, v in data.items()})
     dfp.loc[:, 'event'] = evns
     return dfp
 
@@ -152,6 +160,7 @@ if __name__ == '__main__':
         # calculate eta
         if 'eta' not in df.columns:
             df.loc[:, 'eta'] = -np.log(np.tan(df['polar.theta'].values/2.))
+        # print(df[['eta', 'polar.theta', 'position.x', 'position.y', 'position.z']])
         fig, axs = plt.subplots(2, 2, figsize=(12, 8), dpi=160)
         ncl = df.groupby('event')['clusterID'].nunique().values
         axs[0][0].hist(ncl, weights=np.repeat(1./float(ncl.shape[0]), ncl.shape[0]),
diff --git a/benchmarks/ecal/scripts/draw_cluters.py b/benchmarks/ecal/scripts/draw_cluters.py
index a346731d..fd03f2dc 100644
--- a/benchmarks/ecal/scripts/draw_cluters.py
+++ b/benchmarks/ecal/scripts/draw_cluters.py
@@ -28,15 +28,23 @@ def flatten_collection(rdf, collection, cols=None):
         cols = [str(c) for c in rdf.GetColumnNames() if str(c).startswith('{}.'.format(collection))]
     else:
         cols = ['{}.{}'.format(collection, c) for c in cols]
+    if not cols:
+        print('cannot find any branch under collection {}'.format(collection))
+        return pd.DataFrame()
+
     data = rdf.AsNumpy(cols)
     # flatten the data, add an event id to identify clusters from different events
     evns = []
     for i, vec in enumerate(data[cols[0]]):
         evns += [i]*vec.size()
     for n, vals in data.items():
-        data[n] = np.asarray([v for vec in vals for v in vec])
+        # make sure ints are not converted to floats
+        typename = vals[0].__class__.__name__.lower()
+        dtype = np.int64 if 'int' in typename or 'long' in typename else np.float64
+        # type safe creation
+        data[n] = np.asarray([v for vec in vals for v in vec], dtype=dtype)
     # build data frame
-    dfp = pd.DataFrame(columns=cols, data=np.vstack(list(data.values())).T)
+    dfp = pd.DataFrame({c: pd.Series(v) for c, v in data.items()})
     dfp.loc[:, 'event'] = evns
     return dfp
 
-- 
GitLab