Commit 89d8ce4d authored by David Blyth's avatar David Blyth

Saving state

parent 8d2388dc
...@@ -35,6 +35,17 @@ void printBFields() { ...@@ -35,6 +35,17 @@ void printBFields() {
} }
} }
genfit::AbsMeasurement *makeIPMeasurement() {
TVectorD hitCoords(3);
hitCoords = 0;
TMatrixDSym hitCov(3);
hitCov = 0;
hitCov[0][0] = 10;
hitCov[1][1] = 10;
hitCov[2][2] = 10;
return new genfit::SpacepointMeasurement(hitCoords, hitCov, 0, 0, nullptr);
}
genfit::AbsMeasurement *makeMeasurement(eic::EnergyDep *entry) { genfit::AbsMeasurement *makeMeasurement(eic::EnergyDep *entry) {
if (entry->mean() / entry->noise() < 2) return NULL; if (entry->mean() / entry->noise() < 2) return NULL;
...@@ -143,13 +154,14 @@ void checkMetadata(proio::Event *event) { ...@@ -143,13 +154,14 @@ void checkMetadata(proio::Event *event) {
} }
} }
void quickTrackProcess(genfit::KalmanFitter *fitter, genfit::Track *track) { void quickTrackProcess(genfit::KalmanFitter *fitter, genfit::Track *track, int n_passes = 1) {
for (auto rep : track->getTrackReps()) { for (auto rep : track->getTrackReps()) {
fitter->processTrackPartially(track, rep, 0, -1); for (int i = 0; i < n_passes; i++) {
track->reverseTrack(); track->reverseTrack();
fitter->processTrackPartially(track, rep, 0, -1); fitter->processTrackPartially(track, rep, 0, -1);
track->reverseTrack(); track->reverseTrack();
fitter->processTrackPartially(track, rep, 0, -1); fitter->processTrackPartially(track, rep, 0, -1);
}
} }
} }
...@@ -207,84 +219,97 @@ void trackEvent(proio::Event *event, const int n_validate, genfit::KalmanFitter ...@@ -207,84 +219,97 @@ void trackEvent(proio::Event *event, const int n_validate, genfit::KalmanFitter
auto measurement = measurements[track_obs[0]]; auto measurement = measurements[track_obs[0]];
TVector3 ip_pos(0, 0, 0); TVector3 ip_pos(0, 0, 0);
auto meas_coords = measurement->getRawHitCoords(); auto meas_coords = measurement->getRawHitCoords();
TVector3 meas_pos(0.01 * meas_coords[0], 0.01 * meas_coords[1], 0.01 * meas_coords[2]); TVector3 meas_pos(10 * meas_coords[0], 10 * meas_coords[1], 10 * meas_coords[2]);
auto kernel_sort_fn = [&](const uint64_t a_id, const uint64_t b_id) { auto kernel_sort_fn = [&](const uint64_t a_id, const uint64_t b_id) {
auto a_coords = measurements[a_id]->getRawHitCoords(); auto a_coords = measurements[a_id]->getRawHitCoords();
TVector3 a_pos(a_coords[0], a_coords[1], a_coords[2]); TVector3 a_pos(a_coords[0], a_coords[1], a_coords[2]);
auto b_coords = measurements[b_id]->getRawHitCoords(); auto b_coords = measurements[b_id]->getRawHitCoords();
TVector3 b_pos(b_coords[0], b_coords[1], b_coords[2]); TVector3 b_pos(b_coords[0], b_coords[1], b_coords[2]);
return meas_pos.Dot(a_pos) / a_pos.Mag() > meas_pos.Dot(b_pos) / b_pos.Mag(); return meas_pos.Dot(a_pos) / a_pos.Mag() < meas_pos.Dot(b_pos) / b_pos.Mag();
}; };
std::sort(available_obs.begin(), available_obs.end(), kernel_sort_fn); std::sort(available_obs.begin(), available_obs.end(), kernel_sort_fn);
auto track = new genfit::Track; auto track = new genfit::Track;
track->setStateSeed(ip_pos, meas_pos); track->setStateSeed(ip_pos, meas_pos);
track->insertMeasurement(makeIPMeasurement());
track->getPoint(0)->setSortingParameter(0.0);
track->addTrackRep(new genfit::RKTrackRep(11)); track->addTrackRep(new genfit::RKTrackRep(11));
track->addTrackRep(new genfit::RKTrackRep(13)); track->addTrackRep(new genfit::RKTrackRep(13));
int n_seed = 5; int n_seed = 3;
for (int i = 1; i < n_seed; i++) track_obs.push_back(available_obs[i - 1]); for (int i = 0; i < n_seed - 1; i++) {
track_obs.push_back(available_obs.back());
available_obs.pop_back();
}
std::map<genfit::TrackPoint *, uint64_t> id_lookup;
for (auto id : track_obs) { for (auto id : track_obs) {
track->insertMeasurement(measurements[id]->clone()); track->insertMeasurement(measurements[id]->clone());
meas_coords = measurements[id]->getRawHitCoords(); meas_coords = measurements[id]->getRawHitCoords();
std::cout << meas_coords[0] << "\t" << meas_coords[1] << "\t" << meas_coords[2] << std::endl; auto track_point = track->getPoint(-1);
track->getPoint(-1)->setSortingParameter(seed_sort_param[id]); track_point->setSortingParameter(seed_sort_param[id]);
id_lookup[track_point] = id;
} }
available_obs.erase(available_obs.begin(), available_obs.begin() + n_seed - 1);
track->sort(); track->sort();
quickTrackProcess(fitter, track); fitter->processTrack(track);
bool do_extend = false;
for (auto rep : track->getTrackReps())
if (track->getFitStatus(rep)->isFitConverged()) {
do_extend = true;
break;
}
for (auto iter = available_obs.begin(); iter != available_obs.end();) { // extend track
auto measurement = measurements[*iter]; while (do_extend && available_obs.size() > 0) {
meas_coords = track->getPoint(-1)->getRawMeasurement()->getRawHitCoords();
meas_pos = TVector3(0.01 * meas_coords[0], 0.01 * meas_coords[1], 0.01 * meas_coords[2]);
std::sort(available_obs.begin(), available_obs.end(), kernel_sort_fn);
auto id = available_obs.back();
track_obs.push_back(id);
available_obs.pop_back();
auto measurement = measurements[id];
auto coords = measurement->getRawHitCoords(); auto coords = measurement->getRawHitCoords();
std::cout << coords[0] << "\t" << coords[1] << "\t" << coords[2] << std::endl; track->insertMeasurement(measurement->clone());
track->insertMeasurement(measurements[*iter]->clone());
auto track_point = track->getPoint(-1); auto track_point = track->getPoint(-1);
track_point->setSortingParameter(seed_sort_param[*iter]); id_lookup[track_point] = id;
track_point->setSortingParameter(seed_sort_param[id]);
track->sort(); track->sort();
track_obs.push_back(*iter);
iter = available_obs.erase(iter);
quickTrackProcess(fitter, track);
std::vector<genfit::AbsTrackRep *> all_reps(track->getTrackReps()); std::vector<genfit::AbsTrackRep *> all_reps(track->getTrackReps());
std::vector<genfit::AbsTrackRep *> bad_reps; std::vector<genfit::AbsTrackRep *> bad_reps;
for (auto rep : all_reps) { for (auto rep : all_reps) {
fitter->processTrackPartially(track, rep, -2, -1);
double pval = forwardPVal(track, rep); double pval = forwardPVal(track, rep);
std::cout << "pval for pdg code " << rep->getPDG() << ": " << pval << std::endl;
if (fabs(pval - 0.5) < 0.49) continue; if (fabs(pval - 0.5) < 0.49) continue;
bad_reps.push_back(rep); bad_reps.push_back(rep);
} }
if (bad_reps.size() == all_reps.size()) { if (bad_reps.size() == all_reps.size()) {
std::cout << "breaking extension" << std::endl; for (int i = 1; i < track->getNumPoints(); i++) {
for (int i = 0; i < track->getNumPoints(); i++) {
if (track_point == track->getPoint(i)) { if (track_point == track->getPoint(i)) {
track->deletePoint(i); track->deletePoint(i);
break; break;
} }
} }
available_obs.push_back(*iter); available_obs.push_back(id);
track_obs.pop_back(); track_obs.pop_back();
break; break;
} }
for (auto rep : bad_reps) track->deleteTrackRep(track->getIdForRep(rep)); for (auto rep : bad_reps) track->deleteTrackRep(track->getIdForRep(rep));
quickTrackProcess(fitter, track);
} }
track->determineCardinalRep(); track->determineCardinalRep();
auto rep = track->getCardinalRep();
// fit and prune track // fit track
std::cout << "track_obs.size(): " << track_obs.size() << std::endl; if (track_obs.size() >= n_validate) {
while (track_obs.size() >= n_validate) {
auto rep = track->getCardinalRep();
fitter->processTrackWithRep(track, rep); fitter->processTrackWithRep(track, rep);
if (track->getFitStatus(rep)->isFitConverged()) { if (track->getFitStatus(rep)->isFitConverged()) {
auto eic_track = new eic::Track; auto eic_track = new eic::Track;
try { auto fieldMgr = genfit::FieldManager::getInstance();
std::cout << "successful fit" << std::endl; for (unsigned int id = 0; id < track->getNumPoints() - 1; id++) {
auto fieldMgr = genfit::FieldManager::getInstance(); try {
for (unsigned int id = 0; id < track->getNumPoints() - 1; id++) { auto point = track->getPoint(id);
auto point = track->getPointWithMeasurement(id); auto info = static_cast<genfit::KalmanFitterInfo *>(point->getFitterInfo(rep));
auto info = point->getFitterInfo(rep);
auto state = info->getFittedState(); auto state = info->getFittedState();
TVector3 pos = rep->getPos(state); TVector3 pos = rep->getPos(state);
TVector3 mom = rep->getMom(state); TVector3 mom = rep->getMom(state);
...@@ -310,57 +335,39 @@ void trackEvent(proio::Event *event, const int n_validate, genfit::KalmanFitter ...@@ -310,57 +335,39 @@ void trackEvent(proio::Event *event, const int n_validate, genfit::KalmanFitter
magfield->set_z(field.z() / 10); magfield->set_z(field.z() / 10);
segment->set_chargesign(charge); segment->set_chargesign(charge);
segment->set_length(length * 10); segment->set_length(length * 10);
} catch (genfit::Exception e) {
std::cerr << e.what() << std::endl;
} }
auto entryID = event->AddEntry(eic_track, "Tracker");
event->TagEntry(entryID, "Reconstructed");
event->TagEntry(entryID, "Vis");
} catch (const genfit::Exception e) {
std::cout << e.what() << std::endl;
} }
auto entryID = event->AddEntry(eic_track, "Tracker");
event->TagEntry(entryID, "Reconstructed");
event->TagEntry(entryID, "Vis");
track_obs.clear(); track_obs.clear();
break; }
} else { }
std::cout << "failed to converge" << std::endl;
double max_chi2_inc = 0;
int max_chi2_inc_i = 0;
for (int i = 0; i < track_obs.size(); i++) {
auto fi =
static_cast<genfit::KalmanFitterInfo *>(track->getPoint(i)->getKalmanFitterInfo(rep));
if (!fi) {
max_chi2_inc_i = i;
break;
}
auto update = fi->getUpdate(1);
if (!update) {
max_chi2_inc_i = i;
break;
}
auto chi2_inc = update->getChiSquareIncrement();
// auto chi2_inc = fi->getSmoothedChi2();
// if (chi2_inc > max_chi2_inc) {
// max_chi2_inc = chi2_inc;
// max_chi2_inc_i = i;
//}
}
if (max_chi2_inc_i == 0) {
std::cout << "breaking fit" << std::endl;
break;
}
track_obs.erase(track_obs.begin() + max_chi2_inc_i); // remove largest outlying hit from track
track->deletePoint(max_chi2_inc_i); double max_chi2_inc = 0;
int max_chi2_inc_i = 1;
for (int i = 1; i < track_obs.size() + 1; i++) {
auto fi = static_cast<genfit::KalmanFitterInfo *>(track->getPoint(i)->getKalmanFitterInfo(rep));
if (!fi) {
max_chi2_inc_i = i;
break;
}
auto update = fi->getUpdate(1);
if (!update) {
max_chi2_inc_i = i;
break;
} }
auto chi2_inc = update->getChiSquareIncrement();
} }
auto max_chi2_inc_id = id_lookup[track->getPoint(max_chi2_inc_i)];
track_obs.erase(std::remove(track_obs.begin(), track_obs.end(), max_chi2_inc_id), track_obs.end());
// return unconsumed hits to available list // return unconsumed hits to available list
// std::sort(available_obs.begin(), available_obs.end(), seed_sort_fn); for (auto id : track_obs) available_obs.push_back(id);
// seed_sort_param[track_obs[0]] += seed_sort_param[available_obs.back()];
// std::sort(available_obs.begin(), available_obs.end(), seed_sort_fn);
for (int i = 1; i < track_obs.size(); i++) available_obs.push_back(track_obs[i]);
// for (auto id : track_obs) available_obs.push_back(id);
delete track; delete track;
} }
...@@ -387,9 +394,9 @@ int main(int argc, char **argv) { ...@@ -387,9 +394,9 @@ int main(int argc, char **argv) {
genfit::MaterialEffects::getInstance()->setNoEffects(); genfit::MaterialEffects::getInstance()->setNoEffects();
genfit::MaterialEffects::getInstance()->setEnergyLossBrems(false); genfit::MaterialEffects::getInstance()->setEnergyLossBrems(false);
genfit::KalmanFitter fitter; genfit::KalmanFitter fitter;
fitter.setMaxIterations(100); // fitter.setMaxIterations(100);
int comp = 0; int comp = 0;
int n_validate = 5; int n_validate = 6;
int opt; int opt;
while ((opt = getopt(argc, argv, "c:n:h")) != -1) { while ((opt = getopt(argc, argv, "c:n:h")) != -1) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment