From 1c377c8f769fbbe001c7b811ef6583e29493221f Mon Sep 17 00:00:00 2001 From: Chao Peng <cpeng@anl.gov> Date: Mon, 13 Dec 2021 08:53:31 -0600 Subject: [PATCH] add a script to draw edep --- .../scripts/check_edep_dists.py | 135 ++++++++++++++++++ 1 file changed, 135 insertions(+) create mode 100644 benchmarks/imaging_shower_ML/scripts/check_edep_dists.py diff --git a/benchmarks/imaging_shower_ML/scripts/check_edep_dists.py b/benchmarks/imaging_shower_ML/scripts/check_edep_dists.py new file mode 100644 index 00000000..404855b1 --- /dev/null +++ b/benchmarks/imaging_shower_ML/scripts/check_edep_dists.py @@ -0,0 +1,135 @@ +import ROOT +import os +import gc +import numpy as np +import pandas as pd +import argparse +import sys +import matplotlib.pyplot as plt + + +# read from RDataFrame and flatten a given collection, return pandas dataframe +def flatten_collection(rdf, collection, cols=None, event_colname='event'): + if not cols: + cols = [str(c) for c in rdf.GetColumnNames() if str(c).startswith('{}.'.format(collection))] + else: + cols = ['{}.{}'.format(collection, c) for c in cols] + if not cols: + print('cannot find any branch under collection {}'.format(collection)) + return pd.DataFrame() + + data = rdf.AsNumpy(cols) + # flatten the data, add an event id to identify clusters from different events + evns = [] + for i, vec in enumerate(data[cols[0]]): + evns += [i]*vec.size() + for n, vals in data.items(): + # make sure ints are not converted to floats + typename = vals[0].__class__.__name__.lower() + dtype = np.int64 if 'int' in typename or 'long' in typename else np.float32 + # type safe creation + data[n] = np.asarray([v for vec in vals for v in vec], dtype=dtype) + # build data frame + dfp = pd.DataFrame({c: pd.Series(v) for c, v in data.items()}) + dfp.loc[:, event_colname] = evns + return dfp + + +# load root macros, input is an argument string +def load_root_macros(arg_macros): + for path in arg_macros.split(','): + path = path.strip() + if os.path.exists(path): + ROOT.gROOT.Macro(path) + else: + print('\"{}\" does not exist, skip loading it.'.format(path)) + + +def cartesian_to_polar(x, y, z): + r = np.sqrt(x**2 + y**2 + z**2) + rc = np.sqrt(x**2 + y**2) + theta = np.arccos(z / r) + phi = np.arctan2(y, x) + eta = -np.log(np.tan(theta / 2.)) + return r, theta, phi, rc, eta + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('file', type=str, help='path to root file') + parser.add_argument('--sim', action='store_true', help='flag switch between sim and rec data (default rec)') + parser.add_argument('-o', type=str, default='.', dest='outdir', help='output directory') + # parser.add_argument('--compact', type=str, default='', dest='compact', help='compact file') + parser.add_argument('-m', '--macros', type=str, default='rootlogon.C', dest='macros', + help='root macros to load (accept multiple paths separated by \",\")') + parser.add_argument('--branch', type=str, default='EcalBarrelImagingHitsReco', help='name of data branch (eic::CalorimeterHitCollection)') + parser.add_argument('--truth-branch', type=str, default='mcparticles', help='name of truth mc branch') + parser.add_argument('--edep-max', type=float, default=0., help='maximum edep (GeV) to plot') + parser.add_argument('--edep-nbins', type=int, default=200, help='number of bins') + parser.add_argument('--name-tag', type=str, default='test', help='name tag to save the file') + parser.add_argument('--samp-frac', type=float, default=1.0, help='sampling fraction') + args = parser.parse_args() + + os.makedirs(args.outdir, exist_ok=True) + load_root_macros(args.macros) + + # read data and mcparticles + rdf = ROOT.RDataFrame("events", args.file) + + mc_branch = args.truth_branch + dfm = flatten_collection(rdf, mc_branch, ['genStatus', 'pdgID', 'ps.x', 'ps.y', 'ps.z', 'mass']) + dfm.rename(columns={c: c.replace(mc_branch + '.', '') for c in dfm.columns}, inplace=True) + # selete incident particles + dfm = dfm[dfm['genStatus'].isin([0, 1])] + # NOTE: assumed single particles + dfm = dfm.groupby('event').first() + # p, theta, phi, pT, eta = cartesian_to_polar(*dfm[['ps.x', 'ps.y', 'ps.z']].values.T) + + if args.sim: + df = flatten_collection(rdf, args.branch, ['energyDeposit']) + df.rename(columns={c: c.replace(args.branch + '.', '') for c in df.columns}, inplace=True) + df.rename(columns={'energyDeposit': 'energy'}, inplace=True) + else: + df = flatten_collection(rdf, args.branch, ['layer', 'energy', 'position.x', 'position.y', 'position.z']) + df.rename(columns={c: c.replace(args.branch + '.', '') for c in df.columns}, inplace=True) + + dfe = df.groupby('event')['energy'].sum().reset_index() + # determine histrogram bins + if args.edep_max <= 0.: + args.edep_max = dfe['energy'].quantile(0.99)*1.2 + bins = np.linspace(0., args.edep_max, args.edep_nbins + 1) + bincenters = (bins[1:] + bins[:-1])/2. + + # get particle types + fig, ax = plt.subplots(figsize=(16, 9), dpi=120, gridspec_kw={'left': 0.15, 'right': 0.95}) + ax.set_xlabel('Energy Deposit / {:.2f} (GeV)'.format(args.samp_frac), fontsize=24) + ax.set_ylabel('Normalized Counts', fontsize=24) + ax.set_yscale('log') + ax.grid(linestyle=':') + ax.tick_params(labelsize=24) + ax.set_axisbelow(True) + + hist_vals, hist_cols = [], [] + pdgbase = ROOT.TDatabasePDG() + for pdgid in dfm['pdgID'].unique(): + particle = pdgbase.GetParticle(int(pdgid)) + if not particle: + print("Unknown pdgcode {}, they are ignored".format(int(pdgid))) + continue + events_indices = dfm[dfm.loc[:, 'pdgID'] == pdgid].index.unique() + print("{} entries of particle {}".format(len(events_indices), particle.GetName())) + + dfe_part = dfe.loc[dfe['event'].isin(events_indices)] + + edep_vals, _, _ = ax.hist(dfe_part['energy'] / args.samp_frac, + weights = [1. / dfe_part.shape[0]]*dfe_part.shape[0], + histtype='step', bins=bins, label=particle.GetName()) + hist_vals.append(edep_vals) + hist_cols.append(particle.GetName()) + + pd.DataFrame(index=bincenters, data=np.vstack(hist_vals).T, columns=hist_cols)\ + .to_csv('{}.csv'.format(args.name_tag)) + ax.legend(fontsize=24) + fig.savefig('{}.png'.format(args.name_tag)) + + -- GitLab