'''
    A script to calculate the placement of ECal endcap modules.
    lxml is not included in the container; install it with 'pip install lxml'.
    Author: Chao Peng (ANL)
    Date: 06/17/2021
'''

import numpy as np
import argparse
from lxml import etree as ET


# constants: name, value
# Each (name, value) pair is written verbatim as <constant name=... value=.../>
# inside the generated <define> block. Values are expression strings (with
# units such as *mm / *cm) that this script never evaluates; they may reference
# other constants.
# NOTE(review): EcalEndcapN_zmin is not defined in this list; presumably it is
# provided by another compact file loaded together with the output — confirm.
CONSTANTS = [
    ('CrystalModule_sx', '20.0*mm'),
    ('CrystalModule_sy', '20.0*mm'),
    ('CrystalModule_sz', '20.0*cm'),
    ('CrystalModule_wrap', '0.5*mm'),
    ('GlassModule_sx', '40.0*mm'),
    ('GlassModule_sy', '40.0*mm'),
    ('GlassModule_sz', '40.0*cm'),
    ('GlassModule_wrap', '1.0*mm'),
    # z offsets of each module type, written to the placements' z attribute
    ('CrystalModule_z0', '10.*cm'),
    ('GlassModule_z0', '0.0*cm'),
    # detector z position, referenced by the detector's <position z=...>
    ('EcalEndcapN_z0', '-EcalEndcapN_zmin-max(CrystalModule_sz,GlassModule_sz)/2.'),
    # module pitch = module size + wrapper thickness (used by line placements
    # and as the readout grid size)
    ('CrystalModule_dx', 'CrystalModule_sx + CrystalModule_wrap'),
    ('CrystalModule_dy', 'CrystalModule_sy + CrystalModule_wrap'),
    ('GlassModule_dx', 'GlassModule_sx + GlassModule_wrap'),
    ('GlassModule_dy', 'GlassModule_sy + GlassModule_wrap'),
]

# line-by-line alignment start pos, total number of blocks
# Each entry describes one row of a single quadrant (x >= 0, y >= 0); the row
# index is the list position. (start, num) means the row holds `num`
# consecutive modules whose column indices begin at `start`. The quadrant is
# replicated into the other three by mirroring: individual_placement mirrors
# it explicitly, while line placements set mirrorx/mirrory on the sector.
CRYSTAL_ALIGNMENT = [
    (7, 17), (7, 17), (7, 17), (6, 18),
    (6, 18), (5, 19), (3, 19), (0, 22),
    (0, 22), (0, 22), (0, 22), (0, 22),
    (0, 20), (0, 20), (0, 18), (0, 18),
    (0, 16), (0, 16), (0, 14), (0, 14),
    (0, 12), (0, 12), (0, 6),  (0, 6),
]

# same (start, num) per-row format as CRYSTAL_ALIGNMENT, in units of glass modules
GLASS_ALIGNMENT = [
    (12, 11), (12, 11), (12, 11), (11, 12),
    (11, 12), (11, 12), (10, 12), (9, 13),
    (8, 13),  (7, 14),  (6, 14),  (3, 16),
    (0, 19),  (0, 18),  (0, 18),  (0, 17),
    (0, 17),  (0, 15),  (0, 13),  (0, 11),
    (0, 10),  (0, 8),   (0, 6),
]

# calculate positions of modules with a quad-alignment and module size
def individual_placement(alignment, module_x=20.5, module_y=20.5):
    """Compute module center positions for a four-fold mirrored layout.

    ``alignment`` describes one quadrant (x >= 0, y >= 0) row by row: entry
    ``row`` is ``(start, num)``, i.e. the row holds ``num`` consecutive modules
    with column indices ``start .. start + num - 1``. The module at
    ``(col, row)`` is centered at ``((col + 0.5) * module_x,
    (row + 0.5) * module_y)``; the quadrant is then mirrored across the y axis,
    the x axis, and both.

    :param alignment: sequence of ``(start, num)`` pairs, one per row.
    :param module_x: module pitch along x (columns).
    :param module_y: module pitch along y (rows).
    :return: ``(4 * N, 2)`` float array of ``(x, y)`` centers, N being the
        number of modules in one quadrant. The four blocks are ordered
        (+x, +y), (-x, +y), (+x, -y), (-x, -y). An empty alignment yields an
        empty ``(0, 2)`` array.
    """
    # columns run along x and rows along y, so scale them by module_x and
    # module_y respectively; reshape keeps a (0, 2) shape for empty input
    quadrant = np.array(
        [((col + 0.5) * module_x, (row + 0.5) * module_y)
         for row, (start, num) in enumerate(alignment)
         for col in range(start, start + num)],
        dtype=float,
    ).reshape(-1, 2)
    # mirror signs: identity, flip x, flip y, flip both
    signs = np.array([[1., 1.], [-1., 1.], [1., -1.], [-1., -1.]])
    return np.vstack([quadrant * sign for sign in signs])



def _add_sector(placements, pltype, sector, prefix, material, vis, alignment,
                pitch_x, pitch_y, individual):
    """Append one module-type sector (crystal or glass) to ``<placements>``.

    Both sectors share the same XML layout and differ only in the arguments
    below, so they are built by this single helper.

    :param placements: the ``<placements>`` element to append to.
    :param pltype: child tag name, ``'individuals'`` or ``'lines'``.
    :param sector: sector number written to the ``sector`` attribute; it must
        match a ``key_value`` of the readout's MultiSegmentation.
    :param prefix: constant-name prefix, e.g. ``'CrystalModule'``; the
        ``<prefix>_sx/_sy/_sz/_wrap/_dx/_dy/_z0`` constants are referenced.
    :param material: module material name.
    :param vis: module vis attribute.
    :param alignment: per-row ``(begin, nmods)`` pairs for one quadrant.
    :param pitch_x: module pitch along x in mm, only used for individual
        placements (the line form references ``<prefix>_dx`` instead).
    :param pitch_y: module pitch along y in mm, likewise.
    :param individual: if true, write explicit per-module ``<placement>``
        elements; otherwise one ``<line>`` per alignment row.
    :return: the created sector element.
    """
    sec = ET.SubElement(placements, pltype)
    sec.set('sector', '{:d}'.format(sector))

    mod = ET.SubElement(sec, 'module')
    mod.set('sizex', prefix + '_sx')
    mod.set('sizey', prefix + '_sy')
    mod.set('sizez', prefix + '_sz')
    mod.set('material', material)
    mod.set('vis', vis)

    wrap = ET.SubElement(sec, 'wrapper')
    wrap.set('thickness', prefix + '_wrap')
    wrap.set('material', 'Epoxy')
    wrap.set('vis', 'WhiteVis')

    z0 = prefix + '_z0'
    if individual:
        # explicit positions for all four quadrants
        positions = individual_placement(alignment, pitch_x, pitch_y)
        for m, (x, y) in enumerate(positions):
            module = ET.SubElement(sec, 'placement')
            module.set('x', '{:.3f}*mm'.format(x))
            module.set('y', '{:.3f}*mm'.format(y))
            module.set('z', z0)
            module.set('id', '{:d}'.format(m))
    else:
        # one <line> per quadrant row; the mirror flags ask the consumer to
        # replicate the quadrant (presumably done by the detector plugin —
        # not visible in this script)
        sec.set('mirrorx', 'true')
        sec.set('mirrory', 'true')
        for row, (begin, nmods) in enumerate(alignment):
            line = ET.SubElement(sec, 'line')
            line.set('axis', 'x')
            line.set('x', prefix + '_dx/2.')
            # row center sits at (row + 0.5) pitches = dy*(2*row + 1)/2
            line.set('y', '{}_dy*{:d}/2.'.format(prefix, row*2 + 1))
            line.set('z', z0)
            line.set('begin', '{:d}'.format(begin))
            line.set('nmods', '{:d}'.format(nmods))
    return sec


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-s', '--save', default='compact/ce_ecal_crystal_glass.xml',
            help='path to save compact file.')
    parser.add_argument('--individual', dest='indiv', action='store_true',
            help='individual block placements instead of line placements')
    args = parser.parse_args()

    data = ET.Element('lccdd')

    # <define>: one <constant> per CONSTANTS entry
    defines = ET.SubElement(data, 'define')
    for name, value in CONSTANTS:
        constant = ET.SubElement(defines, 'constant')
        constant.set('name', name)
        constant.set('value', value)

    # referenced by both the detector and the readout definition
    readout_name = 'EcalEndcapNHits'

    # detector and its dimension/position/rotation
    dets = ET.SubElement(data, 'detectors')
    cmt = ET.SubElement(dets, 'comment')
    cmt.text = ' Backwards Endcap EM Calorimeter, placements generated by script '

    det = ET.SubElement(dets, 'detector')
    det.set('id', 'ECalEndcapN_ID')
    det.set('name', 'EcalEndcapN')
    det.set('type', 'HomogeneousCalorimeter')
    det.set('readout', readout_name)

    pos = ET.SubElement(det, 'position')
    pos.set('x', '0')
    pos.set('y', '0')
    pos.set('z', 'EcalEndcapN_z0')

    rot = ET.SubElement(det, 'rotation')
    rot.set('x', '0')
    rot.set('y', '0')
    rot.set('z', '0')

    # placements of modules; sector numbers must match the key_value of the
    # corresponding sub-segmentation in the readout below
    plm = ET.SubElement(det, 'placements')
    pltype = 'individuals' if args.indiv else 'lines'
    # crystal sector; pitch 20.5 mm = CrystalModule_sx + CrystalModule_wrap
    # (keep in sync with CONSTANTS)
    _add_sector(plm, pltype, 1, 'CrystalModule', 'PbWO4', 'AnlTeal',
                CRYSTAL_ALIGNMENT, 20.5, 20.5, args.indiv)
    # glass sector; pitch 41.0 mm = GlassModule_sx + GlassModule_wrap
    # (keep in sync with CONSTANTS)
    # TODO: change glass material
    _add_sector(plm, pltype, 2, 'GlassModule', 'PbGlass', 'AnlBlue',
                GLASS_ALIGNMENT, 41.0, 41.0, args.indiv)

    # readout
    readouts = ET.SubElement(data, 'readouts')
    cmt = ET.SubElement(readouts, 'comment')
    cmt.text = 'Effectively no segmentation, the segmentation is used to provide cell dimension info'
    readout = ET.SubElement(readouts, 'readout')
    readout.set('name', readout_name)
    # a segmentation (rather than NoSegmentation) is needed to provide cell
    # dimension info; the sub-segmentation is selected by the 'sector' field
    seg = ET.SubElement(readout, 'segmentation')
    seg.set('type', 'MultiSegmentation')
    seg.set('key', 'sector')
    crystal_seg = ET.SubElement(seg, 'segmentation')
    crystal_seg.set('name', 'CrystalSeg')
    crystal_seg.set('key_value', '1')
    crystal_seg.set('type', 'CartesianGridXY')
    crystal_seg.set('grid_size_x', 'CrystalModule_dx')
    crystal_seg.set('grid_size_y', 'CrystalModule_dy')
    glass_seg = ET.SubElement(seg, 'segmentation')
    glass_seg.set('name', 'GlassSeg')
    glass_seg.set('key_value', '2')
    glass_seg.set('type', 'CartesianGridXY')
    glass_seg.set('grid_size_x', 'GlassModule_dx')
    glass_seg.set('grid_size_y', 'GlassModule_dy')
    rid = ET.SubElement(readout, 'id')
    rid.text = 'system:8,sector:4,module:20,x:32:-16,y:-16'

    text = ET.tostring(data, pretty_print=True)
    with open(args.save, 'wb') as f:
        f.write(text)