Commit 89d8ce4d authored by David Blyth's avatar David Blyth

Saving state

parent 8d2388dc
......@@ -35,6 +35,17 @@ void printBFields() {
}
}
genfit::AbsMeasurement *makeIPMeasurement() {
TVectorD hitCoords(3);
hitCoords = 0;
TMatrixDSym hitCov(3);
hitCov = 0;
hitCov[0][0] = 10;
hitCov[1][1] = 10;
hitCov[2][2] = 10;
return new genfit::SpacepointMeasurement(hitCoords, hitCov, 0, 0, nullptr);
}
genfit::AbsMeasurement *makeMeasurement(eic::EnergyDep *entry) {
if (entry->mean() / entry->noise() < 2) return NULL;
......@@ -143,13 +154,14 @@ void checkMetadata(proio::Event *event) {
}
}
void quickTrackProcess(genfit::KalmanFitter *fitter, genfit::Track *track) {
void quickTrackProcess(genfit::KalmanFitter *fitter, genfit::Track *track, int n_passes = 1) {
for (auto rep : track->getTrackReps()) {
fitter->processTrackPartially(track, rep, 0, -1);
track->reverseTrack();
fitter->processTrackPartially(track, rep, 0, -1);
track->reverseTrack();
fitter->processTrackPartially(track, rep, 0, -1);
for (int i = 0; i < n_passes; i++) {
track->reverseTrack();
fitter->processTrackPartially(track, rep, 0, -1);
track->reverseTrack();
fitter->processTrackPartially(track, rep, 0, -1);
}
}
}
......@@ -207,84 +219,97 @@ void trackEvent(proio::Event *event, const int n_validate, genfit::KalmanFitter
auto measurement = measurements[track_obs[0]];
TVector3 ip_pos(0, 0, 0);
auto meas_coords = measurement->getRawHitCoords();
TVector3 meas_pos(0.01 * meas_coords[0], 0.01 * meas_coords[1], 0.01 * meas_coords[2]);
TVector3 meas_pos(10 * meas_coords[0], 10 * meas_coords[1], 10 * meas_coords[2]);
auto kernel_sort_fn = [&](const uint64_t a_id, const uint64_t b_id) {
auto a_coords = measurements[a_id]->getRawHitCoords();
TVector3 a_pos(a_coords[0], a_coords[1], a_coords[2]);
auto b_coords = measurements[b_id]->getRawHitCoords();
TVector3 b_pos(b_coords[0], b_coords[1], b_coords[2]);
return meas_pos.Dot(a_pos) / a_pos.Mag() > meas_pos.Dot(b_pos) / b_pos.Mag();
return meas_pos.Dot(a_pos) / a_pos.Mag() < meas_pos.Dot(b_pos) / b_pos.Mag();
};
std::sort(available_obs.begin(), available_obs.end(), kernel_sort_fn);
auto track = new genfit::Track;
track->setStateSeed(ip_pos, meas_pos);
track->insertMeasurement(makeIPMeasurement());
track->getPoint(0)->setSortingParameter(0.0);
track->addTrackRep(new genfit::RKTrackRep(11));
track->addTrackRep(new genfit::RKTrackRep(13));
int n_seed = 5;
for (int i = 1; i < n_seed; i++) track_obs.push_back(available_obs[i - 1]);
int n_seed = 3;
for (int i = 0; i < n_seed - 1; i++) {
track_obs.push_back(available_obs.back());
available_obs.pop_back();
}
std::map<genfit::TrackPoint *, uint64_t> id_lookup;
for (auto id : track_obs) {
track->insertMeasurement(measurements[id]->clone());
meas_coords = measurements[id]->getRawHitCoords();
std::cout << meas_coords[0] << "\t" << meas_coords[1] << "\t" << meas_coords[2] << std::endl;
track->getPoint(-1)->setSortingParameter(seed_sort_param[id]);
auto track_point = track->getPoint(-1);
track_point->setSortingParameter(seed_sort_param[id]);
id_lookup[track_point] = id;
}
available_obs.erase(available_obs.begin(), available_obs.begin() + n_seed - 1);
track->sort();
quickTrackProcess(fitter, track);
fitter->processTrack(track);
bool do_extend = false;
for (auto rep : track->getTrackReps())
if (track->getFitStatus(rep)->isFitConverged()) {
do_extend = true;
break;
}
for (auto iter = available_obs.begin(); iter != available_obs.end();) {
auto measurement = measurements[*iter];
// extend track
while (do_extend && available_obs.size() > 0) {
meas_coords = track->getPoint(-1)->getRawMeasurement()->getRawHitCoords();
meas_pos = TVector3(0.01 * meas_coords[0], 0.01 * meas_coords[1], 0.01 * meas_coords[2]);
std::sort(available_obs.begin(), available_obs.end(), kernel_sort_fn);
auto id = available_obs.back();
track_obs.push_back(id);
available_obs.pop_back();
auto measurement = measurements[id];
auto coords = measurement->getRawHitCoords();
std::cout << coords[0] << "\t" << coords[1] << "\t" << coords[2] << std::endl;
track->insertMeasurement(measurements[*iter]->clone());
track->insertMeasurement(measurement->clone());
auto track_point = track->getPoint(-1);
track_point->setSortingParameter(seed_sort_param[*iter]);
id_lookup[track_point] = id;
track_point->setSortingParameter(seed_sort_param[id]);
track->sort();
track_obs.push_back(*iter);
iter = available_obs.erase(iter);
quickTrackProcess(fitter, track);
std::vector<genfit::AbsTrackRep *> all_reps(track->getTrackReps());
std::vector<genfit::AbsTrackRep *> bad_reps;
for (auto rep : all_reps) {
fitter->processTrackPartially(track, rep, -2, -1);
double pval = forwardPVal(track, rep);
std::cout << "pval for pdg code " << rep->getPDG() << ": " << pval << std::endl;
if (fabs(pval - 0.5) < 0.49) continue;
bad_reps.push_back(rep);
}
if (bad_reps.size() == all_reps.size()) {
std::cout << "breaking extension" << std::endl;
for (int i = 0; i < track->getNumPoints(); i++) {
for (int i = 1; i < track->getNumPoints(); i++) {
if (track_point == track->getPoint(i)) {
track->deletePoint(i);
break;
}
}
available_obs.push_back(*iter);
available_obs.push_back(id);
track_obs.pop_back();
break;
}
for (auto rep : bad_reps) track->deleteTrackRep(track->getIdForRep(rep));
quickTrackProcess(fitter, track);
}
track->determineCardinalRep();
auto rep = track->getCardinalRep();
// fit and prune track
std::cout << "track_obs.size(): " << track_obs.size() << std::endl;
while (track_obs.size() >= n_validate) {
auto rep = track->getCardinalRep();
// fit track
if (track_obs.size() >= n_validate) {
fitter->processTrackWithRep(track, rep);
if (track->getFitStatus(rep)->isFitConverged()) {
auto eic_track = new eic::Track;
try {
std::cout << "successful fit" << std::endl;
auto fieldMgr = genfit::FieldManager::getInstance();
for (unsigned int id = 0; id < track->getNumPoints() - 1; id++) {
auto point = track->getPointWithMeasurement(id);
auto info = point->getFitterInfo(rep);
auto fieldMgr = genfit::FieldManager::getInstance();
for (unsigned int id = 0; id < track->getNumPoints() - 1; id++) {
try {
auto point = track->getPoint(id);
auto info = static_cast<genfit::KalmanFitterInfo *>(point->getFitterInfo(rep));
auto state = info->getFittedState();
TVector3 pos = rep->getPos(state);
TVector3 mom = rep->getMom(state);
......@@ -310,57 +335,39 @@ void trackEvent(proio::Event *event, const int n_validate, genfit::KalmanFitter
magfield->set_z(field.z() / 10);
segment->set_chargesign(charge);
segment->set_length(length * 10);
} catch (genfit::Exception e) {
std::cerr << e.what() << std::endl;
}
auto entryID = event->AddEntry(eic_track, "Tracker");
event->TagEntry(entryID, "Reconstructed");
event->TagEntry(entryID, "Vis");
} catch (const genfit::Exception e) {
std::cout << e.what() << std::endl;
}
auto entryID = event->AddEntry(eic_track, "Tracker");
event->TagEntry(entryID, "Reconstructed");
event->TagEntry(entryID, "Vis");
track_obs.clear();
break;
} else {
std::cout << "failed to converge" << std::endl;
double max_chi2_inc = 0;
int max_chi2_inc_i = 0;
for (int i = 0; i < track_obs.size(); i++) {
auto fi =
static_cast<genfit::KalmanFitterInfo *>(track->getPoint(i)->getKalmanFitterInfo(rep));
if (!fi) {
max_chi2_inc_i = i;
break;
}
auto update = fi->getUpdate(1);
if (!update) {
max_chi2_inc_i = i;
break;
}
auto chi2_inc = update->getChiSquareIncrement();
// auto chi2_inc = fi->getSmoothedChi2();
// if (chi2_inc > max_chi2_inc) {
// max_chi2_inc = chi2_inc;
// max_chi2_inc_i = i;
//}
}
if (max_chi2_inc_i == 0) {
std::cout << "breaking fit" << std::endl;
break;
}
}
}
track_obs.erase(track_obs.begin() + max_chi2_inc_i);
track->deletePoint(max_chi2_inc_i);
// remove largest outlying hit from track
double max_chi2_inc = 0;
int max_chi2_inc_i = 1;
for (int i = 1; i < track_obs.size() + 1; i++) {
auto fi = static_cast<genfit::KalmanFitterInfo *>(track->getPoint(i)->getKalmanFitterInfo(rep));
if (!fi) {
max_chi2_inc_i = i;
break;
}
auto update = fi->getUpdate(1);
if (!update) {
max_chi2_inc_i = i;
break;
}
auto chi2_inc = update->getChiSquareIncrement();
}
auto max_chi2_inc_id = id_lookup[track->getPoint(max_chi2_inc_i)];
track_obs.erase(std::remove(track_obs.begin(), track_obs.end(), max_chi2_inc_id), track_obs.end());
// return unconsumed hits to available list
// std::sort(available_obs.begin(), available_obs.end(), seed_sort_fn);
// seed_sort_param[track_obs[0]] += seed_sort_param[available_obs.back()];
// std::sort(available_obs.begin(), available_obs.end(), seed_sort_fn);
for (int i = 1; i < track_obs.size(); i++) available_obs.push_back(track_obs[i]);
// for (auto id : track_obs) available_obs.push_back(id);
for (auto id : track_obs) available_obs.push_back(id);
delete track;
}
......@@ -387,9 +394,9 @@ int main(int argc, char **argv) {
genfit::MaterialEffects::getInstance()->setNoEffects();
genfit::MaterialEffects::getInstance()->setEnergyLossBrems(false);
genfit::KalmanFitter fitter;
fitter.setMaxIterations(100);
// fitter.setMaxIterations(100);
int comp = 0;
int n_validate = 5;
int n_validate = 6;
int opt;
while ((opt = getopt(argc, argv, "c:n:h")) != -1) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment