From b16f46e02fb37e8f9047db8bae5c8b90b6667c35 Mon Sep 17 00:00:00 2001
From: Chao Peng <cpeng@anl.gov>
Date: Fri, 18 Jun 2021 07:42:15 -0500
Subject: [PATCH] add dataframe return for the script

---
 benchmarks/clustering/scripts/cluster_plots.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/benchmarks/clustering/scripts/cluster_plots.py b/benchmarks/clustering/scripts/cluster_plots.py
index 7a201a7f..9d0bd633 100644
--- a/benchmarks/clustering/scripts/cluster_plots.py
+++ b/benchmarks/clustering/scripts/cluster_plots.py
@@ -91,8 +91,8 @@ def general_clusters_figure(df, collection, save, min_nhits=3):
             data=np.vstack(list(data.values())).T)
     dfp.loc[:, 'evn'] = evns
     # select the max. energy cluster for each event
-    dfp = dfp.loc[dfp.groupby('evn')['edep'].idxmax()]
-    dfp = dfp.loc[dfp['nhits'] >= min_nhits]
+    dfc = dfp.loc[dfp.groupby('evn')['edep'].idxmax()]
+    dfc = dfc.loc[dfc['nhits'] >= min_nhits]
     # figure
     fig, axs = plt.subplots(2, 2, figsize=(16, 12), dpi=120)
     labels = [
@@ -101,7 +101,7 @@ def general_clusters_figure(df, collection, save, min_nhits=3):
         (r'$\theta$ (rad)', 'Counts'),
         (r'$\phi$ (rad)', 'Counts'),
     ]
-    for ax, label, vals in zip(axs.flat, labels, dfp[['nhits', 'edep', 'theta', 'phi']].values.T):
+    for ax, label, vals in zip(axs.flat, labels, dfc[['nhits', 'edep', 'theta', 'phi']].values.T):
         ax.hist(vals, bins=50, ec='k')
         ax.tick_params(labelsize=22, direction='in', which='both')
         ax.grid(linestyle=':', which='both')
@@ -112,6 +112,7 @@ def general_clusters_figure(df, collection, save, min_nhits=3):
     fig.text(0.5, 0.95, collection, ha='center', fontsize=24)
     fig.savefig(save)
     plt.close(fig)
+    return dfp
 
 
 if __name__ == '__main__':
@@ -143,12 +144,14 @@ if __name__ == '__main__':
     rdf_rec = ROOT.RDataFrame('events', args.rec_file)
 
     thrown_particles_figure(rdf_sim, save=os.path.join(args.outdir, 'thrown_particles.png'), mcbranch=args.mc)
+
     general_clusters_figure(rdf_rec, collection='EcalEndcapNClusters', min_nhits=10,
             save=os.path.join(args.outdir, 'ecal_electron_endcap_clusters.png'))
     general_clusters_figure(rdf_rec, collection='EcalEndcapPClusters', min_nhits=5,
             save=os.path.join(args.outdir, 'ecal_hadron_endcap_clusters.png'))
     general_clusters_figure(rdf_rec, collection='EcalBarrelClusters',
             save=os.path.join(args.outdir, 'ecal_barrel_clusters.png'))
+
     general_clusters_figure(rdf_rec, collection='HcalElectronEndcapClusters', min_nhits=5,
             save=os.path.join(args.outdir, 'hcal_electron_endcap_clusters.png'))
     general_clusters_figure(rdf_rec, collection='HcalHadronEndcapClusters', min_nhits=10,
-- 
GitLab