Commit 8d2388dc authored by David Blyth's avatar David Blyth

Saving a pretty good state

parent a034170d
......@@ -143,6 +143,30 @@ void checkMetadata(proio::Event *event) {
}
}
void quickTrackProcess(genfit::KalmanFitter *fitter, genfit::Track *track) {
for (auto rep : track->getTrackReps()) {
fitter->processTrackPartially(track, rep, 0, -1);
track->reverseTrack();
fitter->processTrackPartially(track, rep, 0, -1);
track->reverseTrack();
fitter->processTrackPartially(track, rep, 0, -1);
}
}
double forwardPVal(genfit::Track *track, genfit::AbsTrackRep *rep) {
double chi2 = 0;
double ndf = -1. * rep->getDim();
for (auto point : track->getPointsWithMeasurement()) {
auto fi = point->getKalmanFitterInfo(rep);
if (!fi) continue;
auto update = fi->getForwardUpdate();
if (!update) continue;
chi2 += update->getChiSquareIncrement();
ndf += update->getNdf();
}
return std::max(0., ROOT::Math::chisquared_cdf_c(chi2, ndf));
}
void trackEvent(proio::Event *event, const int n_validate, genfit::KalmanFitter *fitter) {
// build available observations list and a map of observations to
// GenFit measurements
......@@ -156,21 +180,20 @@ void trackEvent(proio::Event *event, const int n_validate, genfit::KalmanFitter
if (measurement) {
available_obs.push_back(id);
measurements[id] = measurement;
// double t_sum = 0;
// double weight_sum = 0;
// for (auto pos : obs->pos()) {
// double weight = pos.weightmod() + 1;
// t_sum += pos.mean().t() * weight;
// weight_sum += weight;
//}
// seed_sort_param[id] = t_sum / weight_sum;
seed_sort_param[id] =
static_cast<eic::SimHit *>(event->GetEntry(obs->source(0)))->globalprepos().t();
double t_sum = 0;
double weight_sum = 0;
for (auto pos : obs->pos()) {
double weight = pos.weightmod() + 1;
t_sum += pos.mean().t() * weight;
weight_sum += weight;
}
seed_sort_param[id] = t_sum / weight_sum;
// seed_sort_param[id] =
// static_cast<eic::SimHit *>(event->GetEntry(obs->source(0)))->globalprepos().t();
// seed_sort_param[id] = measurement->getRawHitCoords().Norm2Sqr();
}
}
}
auto seed_sort_fn = [&](const uint64_t a_id, const uint64_t b_id) {
return seed_sort_param[a_id] > seed_sort_param[b_id];
};
......@@ -208,90 +231,43 @@ void trackEvent(proio::Event *event, const int n_validate, genfit::KalmanFitter
}
available_obs.erase(available_obs.begin(), available_obs.begin() + n_seed - 1);
track->sort();
for (auto rep : track->getTrackReps()) fitter->processTrackPartially(track, rep, 0, -1);
quickTrackProcess(fitter, track);
for (auto iter = available_obs.begin(); iter != available_obs.end();) {
auto tmp_track = new genfit::Track;
tmp_track->setStateSeed(ip_pos, meas_pos);
tmp_track->addTrackRep(new genfit::RKTrackRep(11));
tmp_track->addTrackRep(new genfit::RKTrackRep(13));
for (auto id : track_obs) {
tmp_track->insertMeasurement(measurements[id]->clone());
tmp_track->getPoint(-1)->setSortingParameter(seed_sort_param[id]);
}
auto measurement = measurements[*iter];
auto coords = measurement->getRawHitCoords();
std::cout << coords[0] << "\t" << coords[1] << "\t" << coords[2] << std::endl;
tmp_track->insertMeasurement(measurements[*iter]->clone());
auto track_point = tmp_track->getPoint(-1);
track->insertMeasurement(measurements[*iter]->clone());
auto track_point = track->getPoint(-1);
track_point->setSortingParameter(seed_sort_param[*iter]);
tmp_track->sort();
track->sort();
track_obs.push_back(*iter);
iter = available_obs.erase(iter);
// process and check for goodness
std::vector<genfit::AbsTrackRep *> all_reps(tmp_track->getTrackReps());
quickTrackProcess(fitter, track);
std::vector<genfit::AbsTrackRep *> all_reps(track->getTrackReps());
std::vector<genfit::AbsTrackRep *> bad_reps;
if (tmp_track->getNumPoints() >= n_validate) {
for (auto rep : all_reps) {
fitter->processTrackPartially(tmp_track, rep, 0, -1);
tmp_track->reverseTrack();
fitter->processTrackPartially(tmp_track, rep, 0, -1);
tmp_track->reverseTrack();
fitter->processTrackPartially(tmp_track, rep, 0, -1);
// auto fi = static_cast<genfit::KalmanFitterInfo
// *>(track_point->getKalmanFitterInfo(rep)); if (!fi) {
// bad_reps.push_back(rep);
// continue;
//}
// auto update = fi->getUpdate(1);
// auto chi2_inc = update->getChiSquareIncrement() / update->getNdf();
// std::cout << "chi2_inc for pdg code " << rep->getPDG() << ": " << chi2_inc <<
// std::endl; if (chi2_inc > 5) {
// bad_reps.push_back(rep);
// continue;
//}
// try {
// auto pval = fitter->getPVal(tmp_track, rep, 1);
// std::cout << "pval for pdg code " << rep->getPDG() << ": " << pval << std::endl;
// if (fabs(pval - 0.5) < .49) continue;
//} catch (genfit::Exception e) {
// ;
//}
double chi2 = 0;
double ndf = -1. * rep->getDim();
auto points = tmp_track->getPointsWithMeasurement();
for (auto iter = points.begin() + 0; iter != points.end(); iter++) {
auto fi = (*iter)->getKalmanFitterInfo(rep);
if (!fi) continue;
auto update = fi->getForwardUpdate();
chi2 += update->getChiSquareIncrement();
std::cout << "chi2 inc: " << update->getChiSquareIncrement() << std::endl;
ndf += update->getNdf();
for (auto rep : all_reps) {
double pval = forwardPVal(track, rep);
std::cout << "pval for pdg code " << rep->getPDG() << ": " << pval << std::endl;
if (fabs(pval - 0.5) < 0.49) continue;
bad_reps.push_back(rep);
}
if (bad_reps.size() == all_reps.size()) {
std::cout << "breaking extension" << std::endl;
for (int i = 0; i < track->getNumPoints(); i++) {
if (track_point == track->getPoint(i)) {
track->deletePoint(i);
break;
}
std::cout << "chi2: " << chi2 << ", ndf: " << ndf << std::endl;
double pval = std::max(0., ROOT::Math::chisquared_cdf_c(chi2, ndf));
std::cout << "pval for pdg code " << rep->getPDG() << ": " << pval << std::endl;
if (fabs(pval - 0.5) < 0.49) continue;
bad_reps.push_back(rep);
// auto len = tmp_track->getTrackLen(rep, -2, -1);
// if (len < 0) {
// bad_reps.push_back(rep);
// continue;
//}
}
if (bad_reps.size() == all_reps.size()) {
iter++;
std::cout << "skipping hit" << std::endl;
break;
}
available_obs.push_back(*iter);
track_obs.pop_back();
break;
}
// clear bad track reps and pop hit from available list
for (auto rep : bad_reps) tmp_track->deleteTrackRep(tmp_track->getIdForRep(rep));
track_obs.push_back(*iter);
delete track;
track = tmp_track;
iter = available_obs.erase(iter);
for (auto rep : bad_reps) track->deleteTrackRep(track->getIdForRep(rep));
}
track->determineCardinalRep();
......@@ -465,7 +441,7 @@ int main(int argc, char **argv) {
trackEvent(&event, n_validate, &fitter);
writer->Push(&event);
n_events++;
if (n_events == 2) break;
if (n_events == 10) break;
}
delete writer;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment