18 #include "TStopwatch.h"
88 int main(
int argc,
char *argv[]) {
121 if (argc == 1 || argc == 2)
124 MACH3LOG_ERROR(
"./RHat Ntoys MCMCchain_1.root MCMCchain_2.root MCMCchain_3.root ... [how many you like]");
128 Ntoys = atoi(argv[1]);
131 for (
int i = 2; i < argc; i++)
133 MCMCFile.push_back(std::string(argv[i]));
146 MACH3LOG_WARN(
"Gelman is going to be sad :(. He suggested you should use more than one chain (at least 4). Code works fine for one chain, however, estimator might be biased.");
147 MACH3LOG_WARN(
"Multiple chains are more likely to reveal multimodality and poor adaptation or mixing:");
169 auto rnd = std::make_unique<TRandom3>(0);
176 std::vector<unsigned int> BurnIn(
Nchains);
177 std::vector<unsigned int> nEntries(
Nchains);
178 std::vector<int> nBranches(
Nchains);
179 std::vector<unsigned int> step(
Nchains);
191 for (
int m = 0; m <
Nchains; m++)
193 TChain* Chain =
new TChain(
"posteriors");
196 nEntries[m] =
static_cast<unsigned int>(Chain->GetEntries());
199 BurnIn[m] = nEntries[m]/5;
202 TObjArray* brlis = Chain->GetListOfBranches();
205 nBranches[m] = brlis->GetEntries();
210 Chain->SetBranchStatus(
"*",
false);
214 for (
int i = 0; i < nBranches[m]; i++)
217 TBranch* br =
static_cast<TBranch *
>(brlis->At(i));
222 TString bname = br->GetName();
225 if (bname ==
"step") {
226 Chain->SetBranchStatus(bname,
true);
227 Chain->SetBranchAddress(bname, &step[m]);
230 else if (bname.BeginsWith(
"PCA_") || bname.BeginsWith(
"accProb") || bname.BeginsWith(
"stepTime") )
241 if(bname.BeginsWith(
"LogL"))
250 Chain->SetBranchStatus(bname,
true);
260 if(nBranches[m] != nBranches[0])
262 MACH3LOG_ERROR(
"Ups, something went wrong, chain {} called {} has {} branches, while 0 called {} has {} branches", m,
MCMCFile[m], nBranches[m],
MCMCFile[0], nBranches[0]);
263 MACH3LOG_ERROR(
"All chains should have the same number of branches");
271 for(
int i = 0; i <
Ntoys; i++)
275 for(
int j = 0; j <
nDraw; j++)
284 double* branch_values =
new double[
nDraw]();
285 for (
int j = 0; j <
nDraw; ++j)
287 Chain->SetBranchAddress(
BranchNames[j].Data(), &branch_values[j]);
291 if(BurnIn[m] >= nEntries[m])
293 MACH3LOG_ERROR(
"You are running on a chain shorter than BurnIn cut");
294 MACH3LOG_ERROR(
"Number of entries {} BurnIn cut {}", nEntries[m], BurnIn[m]);
300 for (
int i = 0; i <
Ntoys; i++)
303 int entry = int(nEntries[m]*rnd->Rndm());
305 Chain->GetEntry(entry);
309 if (step[m] < BurnIn[m])
322 for (
int j = 0; j <
nDraw; ++j) {
323 Draws[m][i][j] = branch_values[j];
329 delete[] branch_values;
336 #pragma omp parallel for
338 for(
int j = 0; j <
nDraw; j++)
341 std::vector<double> TempDraws(
static_cast<size_t>(
Ntoys) *
Nchains);
342 for(
int m = 0; m <
Nchains; m++)
344 for(
int i = 0; i <
Ntoys; i++)
347 TempDraws[im] =
Draws[m][i][j];
354 #pragma omp parallel for collapse(3)
356 for(
int m = 0; m <
Nchains; m++)
358 for(
int i = 0; i <
Ntoys; i++)
360 for(
int j = 0; j <
nDraw; j++)
367 MACH3LOG_INFO(
"Finished calculating Toys, it took {:.2f}s to finish", clock.RealTime());
398 for (
int m = 0; m <
Nchains; ++m)
405 for (
int j = 0; j <
nDraw; ++j)
456 #pragma omp for collapse(2)
460 for (
int m = 0; m <
Nchains; ++m)
462 for (
int j = 0; j <
nDraw; ++j)
464 for(
int i = 0; i <
Ntoys; i++)
478 for (
int j = 0; j <
nDraw; ++j)
480 for (
int m = 0; m <
Nchains; ++m)
491 #pragma omp for collapse(2)
494 for (
int m = 0; m <
Nchains; ++m)
496 for (
int j = 0; j <
nDraw; ++j)
498 for(
int i = 0; i <
Ntoys; i++)
512 for (
int j = 0; j <
nDraw; ++j)
514 for (
int m = 0; m <
Nchains; ++m)
526 for (
int j = 0; j <
nDraw; ++j)
536 for (
int m = 0; m <
Nchains; ++m)
549 for (
int j = 0; j <
nDraw; ++j)
559 for (
int j = 0; j <
nDraw; ++j)
573 for (
int j = 0; j <
nDraw; ++j)
587 MACH3LOG_INFO(
"Finished calculating RHat, it took {:.2f}s to finish", clock.RealTime());
594 #pragma GCC diagnostic ignored "-Wfloat-conversion"
596 std::string NameTemp =
"";
600 for (
int i = 0; i <
Nchains; i++)
604 while (temp.find(
".root") != std::string::npos) {
605 temp = temp.substr(0, temp.find(
".root"));
608 const auto slash = temp.find_last_of(
"/\\");
609 if (slash != std::string::npos) {
610 temp = temp.substr(slash + 1);
612 NameTemp = NameTemp + temp +
"_";
616 NameTemp = std::to_string(
Nchains) +
"Chains" +
"_";
618 NameTemp +=
"diag.root";
620 TFile *DiagFile =
M3::Open(NameTemp,
"recreate", __FILE__, __LINE__);
623 TH1D *StandardDeviationGlobalPlot =
new TH1D(
"StandardDeviationGlobalPlot",
"StandardDeviationGlobalPlot",
nDraw, 0,
nDraw);
624 TH1D *BetweenChainVariancePlot =
new TH1D(
"BetweenChainVariancePlot",
"BetweenChainVariancePlot",
nDraw, 0,
nDraw);
625 TH1D *MarginalPosteriorVariancePlot =
new TH1D(
"MarginalPosteriorVariancePlot",
"MarginalPosteriorVariancePlot",
nDraw, 0,
nDraw);
626 TH1D *RhatPlot =
new TH1D(
"RhatPlot",
"RhatPlot", 200, 0, 2);
627 TH1D *EffectiveSampleSizePlot =
new TH1D(
"EffectiveSampleSizePlot",
"EffectiveSampleSizePlot", 400, 0, 10000);
629 TH1D *StandardDeviationGlobalFoldedPlot =
new TH1D(
"StandardDeviationGlobalFoldedPlot",
"StandardDeviationGlobalFoldedPlot",
nDraw, 0,
nDraw);
630 TH1D *BetweenChainVarianceFoldedPlot =
new TH1D(
"BetweenChainVarianceFoldedPlot",
"BetweenChainVarianceFoldedPlot",
nDraw, 0,
nDraw);
631 TH1D *MarginalPosteriorVarianceFoldedPlot =
new TH1D(
"MarginalPosteriorVarianceFoldedPlot",
"MarginalPosteriorVarianceFoldedPlot",
nDraw, 0,
nDraw);
632 TH1D *RhatFoldedPlot =
new TH1D(
"RhatFoldedPlot",
"RhatFoldedPlot", 200, 0, 2);
633 TH1D *EffectiveSampleSizeFoldedPlot =
new TH1D(
"EffectiveSampleSizeFoldedPlot",
"EffectiveSampleSizeFoldedPlot", 400, 0, 10000);
635 TH1D *RhatLogPlot =
new TH1D(
"RhatLogPlot",
"RhatLogPlot", 200, 0, 2);
636 TH1D *RhatFoldedLogPlot =
new TH1D(
"RhatFoldedLogPlot",
"RhatFoldedLogPlot", 200, 0, 2);
639 int CiteriumFolded = 0;
640 for(
int j = 0; j <
nDraw; j++)
648 RhatPlot->Fill(
RHat[j]);
650 if(
RHat[j] > 1.1) Criterium++;
662 RhatLogPlot->Fill(
RHat[j]);
667 MACH3LOG_WARN(
"Number of parameters which has R hat greater than 1.1 is {}({:.2f}%) while for R hat folded {}({:.2f}%)", Criterium, 100*
double(Criterium)/
double(
nDraw), CiteriumFolded, 100*
double(CiteriumFolded)/
double(
nDraw));
668 for(
int j = 0; j <
nDraw; j++)
675 StandardDeviationGlobalPlot->Write();
676 BetweenChainVariancePlot->Write();
677 MarginalPosteriorVariancePlot->Write();
679 EffectiveSampleSizePlot->Write();
681 StandardDeviationGlobalFoldedPlot->Write();
682 BetweenChainVarianceFoldedPlot->Write();
683 MarginalPosteriorVarianceFoldedPlot->Write();
684 RhatFoldedPlot->Write();
685 EffectiveSampleSizeFoldedPlot->Write();
687 RhatLogPlot->Write();
688 RhatFoldedLogPlot->Write();
691 auto TempCanvas = std::make_unique<TCanvas>(
"Canvas",
"Canvas", 1024, 1024);
692 gStyle->SetOptStat(0);
693 TempCanvas->SetGridx();
694 TempCanvas->SetGridy();
697 auto TempLine = std::make_unique<TLine>(0, 0, 0, 0);
698 TempLine->SetLineColor(kBlack);
700 RhatPlot->GetXaxis()->SetTitle(
"R hat");
701 RhatPlot->SetLineColor(kRed);
702 RhatPlot->SetFillColor(kRed);
703 RhatFoldedPlot->SetLineColor(kBlue);
704 RhatFoldedPlot->SetFillColor(kBlue);
706 TLegend *Legend =
new TLegend(0.55, 0.6, 0.9, 0.9);
707 Legend->SetTextSize(0.04);
708 Legend->SetFillColor(0);
709 Legend->SetFillStyle(0);
710 Legend->SetLineWidth(0);
711 Legend->SetLineColor(0);
713 Legend->AddEntry(TempLine.get(), Form(
"Number of throws=%.0i, Number of chains=%.1i",
Ntoys,
Nchains),
"");
714 Legend->AddEntry(RhatPlot,
"Rhat Gelman 2013",
"l");
715 Legend->AddEntry(RhatFoldedPlot,
"Rhat-Folded Gelman 2021",
"l");
718 RhatFoldedPlot->Draw(
"same");
719 Legend->Draw(
"same");
720 TempCanvas->Write(
"Rhat");
725 RhatLogPlot->GetXaxis()->SetTitle(
"R hat for LogL");
726 RhatLogPlot->SetLineColor(kRed);
727 RhatLogPlot->SetFillColor(kRed);
728 RhatFoldedLogPlot->SetLineColor(kBlue);
729 RhatFoldedLogPlot->SetFillColor(kBlue);
731 Legend =
new TLegend(0.55, 0.6, 0.9, 0.9);
732 Legend->SetTextSize(0.04);
733 Legend->SetFillColor(0);
734 Legend->SetFillStyle(0);
735 Legend->SetLineWidth(0);
736 Legend->SetLineColor(0);
738 Legend->AddEntry(TempLine.get(), Form(
"Number of throws=%.0i, Number of chains=%.1i",
Ntoys,
Nchains),
"");
739 Legend->AddEntry(RhatLogPlot,
"Rhat Gelman 2013",
"l");
740 Legend->AddEntry(RhatFoldedLogPlot,
"Rhat-Folded Gelman 2021",
"l");
743 RhatFoldedLogPlot->Draw(
"same");
744 Legend->Draw(
"same");
745 TempCanvas->Write(
"RhatLog");
750 EffectiveSampleSizePlot->GetXaxis()->SetTitle(
"S_{eff, BDA2}");
751 EffectiveSampleSizePlot->SetLineColor(kRed);
752 EffectiveSampleSizeFoldedPlot->SetLineColor(kBlue);
754 Legend =
new TLegend(0.45, 0.6, 0.9, 0.9);
755 Legend->SetTextSize(0.03);
756 Legend->SetFillColor(0);
757 Legend->SetFillStyle(0);
758 Legend->SetLineWidth(0);
759 Legend->SetLineColor(0);
761 const double Mean1 = EffectiveSampleSizePlot->GetMean();
762 const double RMS1 = EffectiveSampleSizePlot->GetRMS();
763 const double Mean2 = EffectiveSampleSizeFoldedPlot->GetMean();
764 const double RMS2 = EffectiveSampleSizeFoldedPlot->GetRMS();
766 Legend->AddEntry(TempLine.get(), Form(
"Number of throws=%.0i, Number of chains=%.1i",
Ntoys,
Nchains),
"");
767 Legend->AddEntry(EffectiveSampleSizePlot, Form(
"S_{eff, BDA2} #mu = %.2f, #sigma = %.2f",Mean1 ,RMS1),
"l");
768 Legend->AddEntry(EffectiveSampleSizeFoldedPlot, Form(
"S_{eff, BDA2} Folded, #mu = %.2f, #sigma = %.2f",Mean2 ,RMS2),
"l");
770 EffectiveSampleSizePlot->Draw();
771 EffectiveSampleSizeFoldedPlot->Draw(
"same");
772 Legend->Draw(
"same");
773 TempCanvas->Write(
"EffectiveSampleSize");
776 delete StandardDeviationGlobalPlot;
777 delete BetweenChainVariancePlot;
778 delete MarginalPosteriorVariancePlot;
780 delete EffectiveSampleSizePlot;
782 delete StandardDeviationGlobalFoldedPlot;
783 delete BetweenChainVarianceFoldedPlot;
784 delete MarginalPosteriorVarianceFoldedPlot;
785 delete RhatFoldedPlot;
786 delete EffectiveSampleSizeFoldedPlot;
791 delete RhatFoldedLogPlot;
819 for(
int m = 0; m <
Nchains; m++)
821 for(
int i = 0; i <
Ntoys; i++)
823 delete[]
Draws[m][i];
848 std::sort(arr, arr+size);
851 return (arr[(size-1)/2] + arr[size/2])/2.0;
858 if(std::isnan(var) || !std::isfinite(var)) var = cap;
#define _MaCh3_Safe_Include_Start_
KS: Suppress warning checking while including external headers.
#define _MaCh3_Safe_Include_End_
KS: Restore warning checking after including external headers.
#define MACH3LOG_CRITICAL
void SetMaCh3LoggerFormat()
Set messaging format of the logger.
int main(int argc, char *argv[])
double * StandardDeviationGlobalFolded
void CapVariable(double var, double cap)
double * EffectiveSampleSizeFolded
double * BetweenChainVarianceFolded
std::vector< bool > ValidPar
double * BetweenChainVariance
double * MeanGlobalFolded
double ** StandardDeviation
double * StandardDeviationGlobal
double * MarginalPosteriorVariance
double ** StandardDeviationFolded
std::vector< TString > BranchNames
std::vector< std::string > MCMCFile
double CalcMedian(double arr[], int size)
double * EffectiveSampleSize
double * MarginalPosteriorVarianceFolded
Custom exception class used throughout MaCh3.
TFile * Open(const std::string &Name, const std::string &Type, const std::string &File, const int Line)
Opens a ROOT file with the given name and mode.
void PrintProgressBar(const Long64_t Done, const Long64_t All)
KS: Simply prints a progress bar.
void MaCh3Welcome()
KS: Prints welcome message with MaCh3 logo.