17#include "TStopwatch.h"
82int main(
int argc,
char *argv[]) {
115 if (argc == 1 || argc == 2)
118 MACH3LOG_ERROR(
"./RHat Ntoys MCMCchain_1.root MCMCchain_2.root MCMCchain_3.root ... [how many you like]");
122 Ntoys = atoi(argv[1]);
125 for (
int i = 2; i < argc; i++)
127 MCMCFile.push_back(std::string(argv[i]));
140 MACH3LOG_WARN(
"Gelman is going to be sad :(. He suggested you should use more than one chain (at least 4). Code works fine for one chain, however, estimator might be biased.");
141 MACH3LOG_WARN(
"Multiple chains are more likely to reveal multimodality and poor adaptation or mixing:");
163 auto rnd = std::make_unique<TRandom3>(0);
170 std::vector<int> BurnIn(
Nchains);
171 std::vector<int> nEntries(
Nchains);
172 std::vector<int> nBranches(
Nchains);
173 std::vector<int> step(
Nchains);
185 for (
int m = 0; m <
Nchains; m++)
187 TChain* Chain =
new TChain(
"posteriors");
190 nEntries[m] = int(Chain->GetEntries());
193 BurnIn[m] = nEntries[m]/5;
196 TObjArray* brlis = Chain->GetListOfBranches();
199 nBranches[m] = brlis->GetEntries();
204 Chain->SetBranchStatus(
"*",
false);
208 for (
int i = 0; i < nBranches[m]; i++)
211 TBranch* br =
static_cast<TBranch *
>(brlis->At(i));
216 TString bname = br->GetName();
219 if (bname ==
"step") {
220 Chain->SetBranchStatus(bname,
true);
221 Chain->SetBranchAddress(bname, &step[m]);
224 else if (bname.BeginsWith(
"PCA_") || bname.BeginsWith(
"accProb") || bname.BeginsWith(
"stepTime") )
235 if(bname.BeginsWith(
"LogL"))
244 Chain->SetBranchStatus(bname,
true);
254 if(nBranches[m] != nBranches[0])
256 MACH3LOG_ERROR(
"Ups, something went wrong, chain {} called {} has {} branches, while 0 called {} has {} branches", m,
MCMCFile[m], nBranches[m],
MCMCFile[0], nBranches[0]);
257 MACH3LOG_ERROR(
"All chains should have the same number of branches");
265 for(
int i = 0; i <
Ntoys; i++)
269 for(
int j = 0; j <
nDraw; j++)
278 double* branch_values =
new double[
nDraw]();
279 for (
int j = 0; j <
nDraw; ++j)
281 Chain->SetBranchAddress(
BranchNames[j].Data(), &branch_values[j]);
285 if(BurnIn[m] >= nEntries[m])
287 MACH3LOG_ERROR(
"You are running on a chain shorter than BurnIn cut");
288 MACH3LOG_ERROR(
"Number of entries {} BurnIn cut {}", nEntries[m], BurnIn[m]);
294 for (
int i = 0; i <
Ntoys; i++)
297 int entry = int(nEntries[m]*rnd->Rndm());
299 Chain->GetEntry(entry);
303 if (step[m] < BurnIn[m])
316 for (
int j = 0; j <
nDraw; ++j)
318 Draws[m][i][j] = branch_values[j];
325 delete[] branch_values;
332 #pragma omp parallel for
334 for(
int j = 0; j <
nDraw; j++)
337 std::vector<double> TempDraws(
static_cast<size_t>(
Ntoys) *
Nchains);
338 for(
int m = 0; m <
Nchains; m++)
340 for(
int i = 0; i <
Ntoys; i++)
343 TempDraws[im] =
Draws[m][i][j];
350 #pragma omp parallel for collapse(3)
352 for(
int m = 0; m <
Nchains; m++)
354 for(
int i = 0; i <
Ntoys; i++)
356 for(
int j = 0; j <
nDraw; j++)
363 MACH3LOG_INFO(
"Finished calculating Toys, it took {:.2f}s to finish", clock.RealTime());
394 for (
int m = 0; m <
Nchains; ++m)
401 for (
int j = 0; j <
nDraw; ++j)
452 #pragma omp for collapse(2)
456 for (
int m = 0; m <
Nchains; ++m)
458 for (
int j = 0; j <
nDraw; ++j)
460 for(
int i = 0; i <
Ntoys; i++)
474 for (
int j = 0; j <
nDraw; ++j)
476 for (
int m = 0; m <
Nchains; ++m)
487 #pragma omp for collapse(2)
490 for (
int m = 0; m <
Nchains; ++m)
492 for (
int j = 0; j <
nDraw; ++j)
494 for(
int i = 0; i <
Ntoys; i++)
508 for (
int j = 0; j <
nDraw; ++j)
510 for (
int m = 0; m <
Nchains; ++m)
522 for (
int j = 0; j <
nDraw; ++j)
532 for (
int m = 0; m <
Nchains; ++m)
545 for (
int j = 0; j <
nDraw; ++j)
555 for (
int j = 0; j <
nDraw; ++j)
569 for (
int j = 0; j <
nDraw; ++j)
583 MACH3LOG_INFO(
"Finished calculating RHat, it took {:.2f}s to finish", clock.RealTime());
590 #pragma GCC diagnostic ignored "-Wfloat-conversion"
592 std::string NameTemp =
"";
596 for (
int i = 0; i <
Nchains; i++)
600 while (temp.find(
".root") != std::string::npos) {
601 temp = temp.substr(0, temp.find(
".root"));
604 NameTemp = NameTemp + temp +
"_";
608 NameTemp = std::to_string(
Nchains) +
"Chains" +
"_";
610 NameTemp +=
"diag.root";
612 TFile* DiagFile =
new TFile(NameTemp.c_str(),
"recreate");
616 TH1D *StandardDeviationGlobalPlot =
new TH1D(
"StandardDeviationGlobalPlot",
"StandardDeviationGlobalPlot",
nDraw, 0,
nDraw);
617 TH1D *BetweenChainVariancePlot =
new TH1D(
"BetweenChainVariancePlot",
"BetweenChainVariancePlot",
nDraw, 0,
nDraw);
618 TH1D *MarginalPosteriorVariancePlot =
new TH1D(
"MarginalPosteriorVariancePlot",
"MarginalPosteriorVariancePlot",
nDraw, 0,
nDraw);
619 TH1D *RhatPlot =
new TH1D(
"RhatPlot",
"RhatPlot", 200, 0, 2);
620 TH1D *EffectiveSampleSizePlot =
new TH1D(
"EffectiveSampleSizePlot",
"EffectiveSampleSizePlot", 400, 0, 10000);
622 TH1D *StandardDeviationGlobalFoldedPlot =
new TH1D(
"StandardDeviationGlobalFoldedPlot",
"StandardDeviationGlobalFoldedPlot",
nDraw, 0,
nDraw);
623 TH1D *BetweenChainVarianceFoldedPlot =
new TH1D(
"BetweenChainVarianceFoldedPlot",
"BetweenChainVarianceFoldedPlot",
nDraw, 0,
nDraw);
624 TH1D *MarginalPosteriorVarianceFoldedPlot =
new TH1D(
"MarginalPosteriorVarianceFoldedPlot",
"MarginalPosteriorVarianceFoldedPlot",
nDraw, 0,
nDraw);
625 TH1D *RhatFoldedPlot =
new TH1D(
"RhatFoldedPlot",
"RhatFoldedPlot", 200, 0, 2);
626 TH1D *EffectiveSampleSizeFoldedPlot =
new TH1D(
"EffectiveSampleSizeFoldedPlot",
"EffectiveSampleSizeFoldedPlot", 400, 0, 10000);
628 TH1D *RhatLogPlot =
new TH1D(
"RhatLogPlot",
"RhatLogPlot", 200, 0, 2);
629 TH1D *RhatFoldedLogPlot =
new TH1D(
"RhatFoldedLogPlot",
"RhatFoldedLogPlot", 200, 0, 2);
632 int CiteriumFolded = 0;
633 for(
int j = 0; j <
nDraw; j++)
641 RhatPlot->Fill(
RHat[j]);
643 if(
RHat[j] > 1.1) Criterium++;
655 RhatLogPlot->Fill(
RHat[j]);
660 MACH3LOG_WARN(
"Number of parameters which has R hat greater than 1.1 is {}({:.2f}%) while for R hat folded {}({:.2f}%)", Criterium, 100*
double(Criterium)/
double(
nDraw), CiteriumFolded, 100*
double(CiteriumFolded)/
double(
nDraw));
661 for(
int j = 0; j <
nDraw; j++)
668 StandardDeviationGlobalPlot->Write();
669 BetweenChainVariancePlot->Write();
670 MarginalPosteriorVariancePlot->Write();
672 EffectiveSampleSizePlot->Write();
674 StandardDeviationGlobalFoldedPlot->Write();
675 BetweenChainVarianceFoldedPlot->Write();
676 MarginalPosteriorVarianceFoldedPlot->Write();
677 RhatFoldedPlot->Write();
678 EffectiveSampleSizeFoldedPlot->Write();
680 RhatLogPlot->Write();
681 RhatFoldedLogPlot->Write();
684 auto TempCanvas = std::make_unique<TCanvas>(
"Canvas",
"Canvas", 1024, 1024);
685 gStyle->SetOptStat(0);
686 TempCanvas->SetGridx();
687 TempCanvas->SetGridy();
690 auto TempLine = std::make_unique<TLine>(0, 0, 0, 0);
691 TempLine->SetLineColor(kBlack);
693 RhatPlot->GetXaxis()->SetTitle(
"R hat");
694 RhatPlot->SetLineColor(kRed);
695 RhatPlot->SetFillColor(kRed);
696 RhatFoldedPlot->SetLineColor(kBlue);
697 RhatFoldedPlot->SetFillColor(kBlue);
699 TLegend *Legend =
new TLegend(0.55, 0.6, 0.9, 0.9);
700 Legend->SetTextSize(0.04);
701 Legend->SetFillColor(0);
702 Legend->SetFillStyle(0);
703 Legend->SetLineWidth(0);
704 Legend->SetLineColor(0);
706 Legend->AddEntry(TempLine.get(), Form(
"Number of throws=%.0i, Number of chains=%.1i",
Ntoys,
Nchains),
"");
707 Legend->AddEntry(RhatPlot,
"Rhat Gelman 2013",
"l");
708 Legend->AddEntry(RhatFoldedPlot,
"Rhat-Folded Gelman 2021",
"l");
711 RhatFoldedPlot->Draw(
"same");
712 Legend->Draw(
"same");
713 TempCanvas->Write(
"Rhat");
718 RhatLogPlot->GetXaxis()->SetTitle(
"R hat for LogL");
719 RhatLogPlot->SetLineColor(kRed);
720 RhatLogPlot->SetFillColor(kRed);
721 RhatFoldedLogPlot->SetLineColor(kBlue);
722 RhatFoldedLogPlot->SetFillColor(kBlue);
724 Legend =
new TLegend(0.55, 0.6, 0.9, 0.9);
725 Legend->SetTextSize(0.04);
726 Legend->SetFillColor(0);
727 Legend->SetFillStyle(0);
728 Legend->SetLineWidth(0);
729 Legend->SetLineColor(0);
731 Legend->AddEntry(TempLine.get(), Form(
"Number of throws=%.0i, Number of chains=%.1i",
Ntoys,
Nchains),
"");
732 Legend->AddEntry(RhatLogPlot,
"Rhat Gelman 2013",
"l");
733 Legend->AddEntry(RhatFoldedLogPlot,
"Rhat-Folded Gelman 2021",
"l");
736 RhatFoldedLogPlot->Draw(
"same");
737 Legend->Draw(
"same");
738 TempCanvas->Write(
"RhatLog");
743 EffectiveSampleSizePlot->GetXaxis()->SetTitle(
"S_{eff, BDA2}");
744 EffectiveSampleSizePlot->SetLineColor(kRed);
745 EffectiveSampleSizeFoldedPlot->SetLineColor(kBlue);
747 Legend =
new TLegend(0.45, 0.6, 0.9, 0.9);
748 Legend->SetTextSize(0.03);
749 Legend->SetFillColor(0);
750 Legend->SetFillStyle(0);
751 Legend->SetLineWidth(0);
752 Legend->SetLineColor(0);
754 const double Mean1 = EffectiveSampleSizePlot->GetMean();
755 const double RMS1 = EffectiveSampleSizePlot->GetRMS();
756 const double Mean2 = EffectiveSampleSizeFoldedPlot->GetMean();
757 const double RMS2 = EffectiveSampleSizeFoldedPlot->GetRMS();
759 Legend->AddEntry(TempLine.get(), Form(
"Number of throws=%.0i, Number of chains=%.1i",
Ntoys,
Nchains),
"");
760 Legend->AddEntry(EffectiveSampleSizePlot, Form(
"S_{eff, BDA2} #mu = %.2f, #sigma = %.2f",Mean1 ,RMS1),
"l");
761 Legend->AddEntry(EffectiveSampleSizeFoldedPlot, Form(
"S_{eff, BDA2} Folded, #mu = %.2f, #sigma = %.2f",Mean2 ,RMS2),
"l");
763 EffectiveSampleSizePlot->Draw();
764 EffectiveSampleSizeFoldedPlot->Draw(
"same");
765 Legend->Draw(
"same");
766 TempCanvas->Write(
"EffectiveSampleSize");
769 delete StandardDeviationGlobalPlot;
770 delete BetweenChainVariancePlot;
771 delete MarginalPosteriorVariancePlot;
773 delete EffectiveSampleSizePlot;
775 delete StandardDeviationGlobalFoldedPlot;
776 delete BetweenChainVarianceFoldedPlot;
777 delete MarginalPosteriorVarianceFoldedPlot;
778 delete RhatFoldedPlot;
779 delete EffectiveSampleSizeFoldedPlot;
784 delete RhatFoldedLogPlot;
812 for(
int m = 0; m <
Nchains; m++)
814 for(
int i = 0; i <
Ntoys; i++)
816 delete[]
Draws[m][i];
841 std::sort(arr, arr+
size);
844 return (arr[(
size-1)/2] + arr[
size/2])/2.0;
851 if(std::isnan(var) || !std::isfinite(var)) var = cap;
#define _MaCh3_Safe_Include_Start_
KS: Avoiding warning checking for headers.
#define _MaCh3_Safe_Include_End_
KS: Restore warning checking after including external headers.
#define MACH3LOG_CRITICAL
void SetMaCh3LoggerFormat()
Set messaging format of the logger.
int main(int argc, char *argv[])
double * StandardDeviationGlobalFolded
void CapVariable(double var, double cap)
double * EffectiveSampleSizeFolded
double * BetweenChainVarianceFolded
std::vector< bool > ValidPar
double * BetweenChainVariance
double * MeanGlobalFolded
double ** StandardDeviation
double * StandardDeviationGlobal
double * MarginalPosteriorVariance
double ** StandardDeviationFolded
std::vector< TString > BranchNames
std::vector< std::string > MCMCFile
double CalcMedian(double arr[], int size)
double * EffectiveSampleSize
double * MarginalPosteriorVarianceFolded
Custom exception class for MaCh3 errors.
void PrintProgressBar(const Long64_t Done, const Long64_t All)
KS: Simply print progress bar.
void MaCh3Welcome()
KS: Prints welcome message with MaCh3 logo.