18 #include "TStopwatch.h"
88 int main(
int argc,
char *argv[]) {
110 MACH3LOG_ERROR(
"./RHat NThin MCMCchain_1.root MCMCchain_2.root MCMCchain_3.root ... [how many you like]");
114 NThin = atoi(argv[1]);
117 for (
int i = 2; i < argc; i++)
119 MCMCFile.push_back(std::string(argv[i]));
126 MACH3LOG_WARN(
"Gelman is going to be sad :(. He suggested you should use more than one chain (at least 4). Code works fine for one chain, however, estimator might be biased.");
127 MACH3LOG_WARN(
"Multiple chains are more likely to reveal multimodality and poor adaptation or mixing:");
156 std::vector<unsigned int> BurnIn(
Nchains);
157 std::vector<unsigned int> nEntries(
Nchains);
158 std::vector<int> nBranches(
Nchains);
159 std::vector<unsigned int> step(
Nchains);
171 for (
int m = 0; m <
Nchains; m++)
173 TChain* Chain =
new TChain(
"posteriors");
176 nEntries[m] =
static_cast<unsigned int>(Chain->GetEntries());
184 BurnIn[m] = nEntries[m]/5;
187 TObjArray* brlis = Chain->GetListOfBranches();
190 nBranches[m] = brlis->GetEntries();
195 Chain->SetBranchStatus(
"*",
false);
199 for (
int i = 0; i < nBranches[m]; i++)
202 TBranch* br =
static_cast<TBranch *
>(brlis->At(i));
207 TString bname = br->GetName();
210 if (bname ==
"step") {
211 Chain->SetBranchStatus(bname,
true);
212 Chain->SetBranchAddress(bname, &step[m]);
215 else if (bname.BeginsWith(
"PCA_") || bname.BeginsWith(
"accProb") || bname.BeginsWith(
"stepTime") )
226 if(bname.BeginsWith(
"LogL"))
235 Chain->SetBranchStatus(bname,
true);
250 for (
int id = 0;
id <
nDraw; ++id)
264 if(nBranches[m] != nBranches[0])
266 MACH3LOG_ERROR(
"Ups, something went wrong, chain {} called {} has {} branches, while 0 called {} has {} branches", m,
MCMCFile[m], nBranches[m],
MCMCFile[0], nBranches[0]);
267 MACH3LOG_ERROR(
"All chains should have the same number of branches");
274 double* branch_values =
new double[
nDraw]();
275 for (
int id = 0;
id <
nDraw; ++id)
277 Chain->SetBranchAddress(
BranchNames[
id].Data(), &branch_values[id]);
281 if(BurnIn[m] >= nEntries[m])
283 MACH3LOG_ERROR(
"You are running on a chain shorter than BurnIn cut");
284 MACH3LOG_ERROR(
"Number of entries {} BurnIn cut {}", nEntries[m], BurnIn[m]);
296 Chain->GetEntry(entry);
300 if (step[m] < BurnIn[m])
314 for (
int j = 0; j <
nDraw; ++j)
317 S2_global[j] += branch_values[j]*branch_values[j];
319 S2_chain[m][j] += branch_values[j]*branch_values[j];
329 delete[] branch_values;
334 MACH3LOG_INFO(
"Finished calculating Toys, it took {:.2f}s to finish", clock.RealTime());
354 for (
int m = 0; m <
Nchains; ++m)
359 for (
int j = 0; j <
nDraw; ++j)
406 for (
int m = 0; m <
Nchains; ++m)
408 for (
int j = 0; j <
nDraw; ++j)
419 for (
int j = 0; j <
nDraw; ++j)
421 for (
int m = 0; m <
Nchains; ++m)
432 for (
int j = 0; j <
nDraw; ++j)
441 for (
int m = 0; m <
Nchains; ++m)
453 for (
int j = 0; j <
nDraw; ++j)
462 for (
int j = 0; j <
nDraw; ++j)
474 for (
int j = 0; j <
nDraw; ++j)
486 MACH3LOG_INFO(
"Finished calculating RHat, it took {:.2f}s to finish", clock.RealTime());
493 #pragma GCC diagnostic ignored "-Wfloat-conversion"
495 std::string NameTemp =
"";
499 for (
int i = 0; i <
Nchains; i++)
503 while (temp.find(
".root") != std::string::npos) {
504 temp = temp.substr(0, temp.find(
".root"));
507 const auto slash = temp.find_last_of(
"/\\");
508 if (slash != std::string::npos) {
509 temp = temp.substr(slash + 1);
512 NameTemp = NameTemp + temp +
"_";
516 NameTemp = std::to_string(
Nchains) +
"Chains" +
"_";
518 NameTemp +=
"diag.root";
520 TFile *DiagFile =
M3::Open(NameTemp,
"recreate", __FILE__, __LINE__);
523 TH1D *StandardDeviationGlobalPlot =
new TH1D(
"StandardDeviationGlobalPlot",
"StandardDeviationGlobalPlot",
nDraw, 0,
nDraw);
524 TH1D *BetweenChainVariancePlot =
new TH1D(
"BetweenChainVariancePlot",
"BetweenChainVariancePlot",
nDraw, 0,
nDraw);
525 TH1D *MarginalPosteriorVariancePlot =
new TH1D(
"MarginalPosteriorVariancePlot",
"MarginalPosteriorVariancePlot",
nDraw, 0,
nDraw);
526 TH1D *RhatPlot =
new TH1D(
"RhatPlot",
"RhatPlot", 200, 0, 2);
527 TH1D *EffectiveSampleSizePlot =
new TH1D(
"EffectiveSampleSizePlot",
"EffectiveSampleSizePlot", 400, 0, 10000);
529 TH1D *RhatLogPlot =
new TH1D(
"RhatLogPlot",
"RhatLogPlot", 200, 0, 2);
532 for(
int j = 0; j <
nDraw; j++)
540 RhatPlot->Fill(
RHat[j]);
542 if(
RHat[j] > 1.1) Criterium++;
546 RhatLogPlot->Fill(
RHat[j]);
550 MACH3LOG_WARN(
"Number of parameters which has R hat greater than 1.1 is {}({:.2f}%)", Criterium, 100*
double(Criterium)/
double(
nDraw));
551 for(
int j = 0; j <
nDraw; j++)
558 StandardDeviationGlobalPlot->Write();
559 BetweenChainVariancePlot->Write();
560 MarginalPosteriorVariancePlot->Write();
562 EffectiveSampleSizePlot->Write();
564 RhatLogPlot->Write();
567 auto TempCanvas = std::make_unique<TCanvas>(
"Canvas",
"Canvas", 1024, 1024);
568 gStyle->SetOptStat(0);
569 TempCanvas->SetGridx();
570 TempCanvas->SetGridy();
573 auto TempLine = std::make_unique<TLine>(0, 0, 0, 0);
574 TempLine->SetLineColor(kBlack);
576 RhatPlot->GetXaxis()->SetTitle(
"R hat");
577 RhatPlot->SetLineColor(kRed);
578 RhatPlot->SetFillColor(kRed);
580 TLegend *Legend =
new TLegend(0.55, 0.6, 0.9, 0.9);
581 Legend->SetTextSize(0.04);
582 Legend->SetFillColor(0);
583 Legend->SetFillStyle(0);
584 Legend->SetLineWidth(0);
585 Legend->SetLineColor(0);
587 Legend->AddEntry(TempLine.get(), Form(
"Number of throws=%.0i, Number of chains=%.1i",
TotToys,
Nchains),
"");
588 Legend->AddEntry(RhatPlot,
"Rhat Gelman 2013",
"l");
591 Legend->Draw(
"same");
592 TempCanvas->Write(
"Rhat");
597 RhatLogPlot->GetXaxis()->SetTitle(
"R hat for LogL");
598 RhatLogPlot->SetLineColor(kRed);
599 RhatLogPlot->SetFillColor(kRed);
601 Legend =
new TLegend(0.55, 0.6, 0.9, 0.9);
602 Legend->SetTextSize(0.04);
603 Legend->SetFillColor(0);
604 Legend->SetFillStyle(0);
605 Legend->SetLineWidth(0);
606 Legend->SetLineColor(0);
608 Legend->AddEntry(TempLine.get(), Form(
"Number of throws=%.0i, Number of chains=%.1i",
TotToys,
Nchains),
"");
609 Legend->AddEntry(RhatLogPlot,
"Rhat Gelman 2013",
"l");
612 Legend->Draw(
"same");
613 TempCanvas->Write(
"RhatLog");
618 EffectiveSampleSizePlot->GetXaxis()->SetTitle(
"S_{eff, BDA2}");
619 EffectiveSampleSizePlot->SetLineColor(kRed);
621 Legend =
new TLegend(0.45, 0.6, 0.9, 0.9);
622 Legend->SetTextSize(0.03);
623 Legend->SetFillColor(0);
624 Legend->SetFillStyle(0);
625 Legend->SetLineWidth(0);
626 Legend->SetLineColor(0);
628 const double Mean1 = EffectiveSampleSizePlot->GetMean();
629 const double RMS1 = EffectiveSampleSizePlot->GetRMS();
631 Legend->AddEntry(TempLine.get(), Form(
"Number of throws=%.0i, Number of chains=%.1i",
TotToys,
Nchains),
"");
632 Legend->AddEntry(EffectiveSampleSizePlot, Form(
"S_{eff, BDA2} #mu = %.2f, #sigma = %.2f",Mean1 ,RMS1),
"l");
634 EffectiveSampleSizePlot->Draw();
635 Legend->Draw(
"same");
636 TempCanvas->Write(
"EffectiveSampleSize");
639 delete StandardDeviationGlobalPlot;
640 delete BetweenChainVariancePlot;
641 delete MarginalPosteriorVariancePlot;
643 delete EffectiveSampleSizePlot;
668 for(
int m = 0; m <
Nchains; m++)
690 std::sort(arr, arr+size);
693 return (arr[(size-1)/2] + arr[size/2])/2.0;
700 if(std::isnan(var) || !std::isfinite(var)) var = cap;
#define _MaCh3_Safe_Include_Start_
KS: Avoiding warning checking for headers.
#define _MaCh3_Safe_Include_End_
KS: Restore warning checking after including external headers.
#define MACH3LOG_CRITICAL
void SetMaCh3LoggerFormat()
Set messaging format of the logger.
int main(int argc, char *argv[])
void CapVariable(double var, double cap)
std::vector< bool > ValidPar
double * BetweenChainVariance
double ** StandardDeviation
double * StandardDeviationGlobal
double * MarginalPosteriorVariance
std::vector< TString > BranchNames
std::vector< std::string > MCMCFile
double CalcMedian(double arr[], int size)
double * EffectiveSampleSize
Custom exception class used throughout MaCh3.
TFile * Open(const std::string &Name, const std::string &Type, const std::string &File, const int Line)
Opens a ROOT file with the given name and mode.
void PrintProgressBar(const Long64_t Done, const Long64_t All)
KS: Simply print progress bar.
void MaCh3Welcome()
KS: Prints welcome message with MaCh3 logo.