17#include "TStopwatch.h"
83int main(
int argc,
char *argv[]) {
105 MACH3LOG_ERROR(
"./RHat NThin MCMCchain_1.root MCMCchain_2.root MCMCchain_3.root ... [how many you like]");
109 NThin = atoi(argv[1]);
112 for (
int i = 2; i < argc; i++)
114 MCMCFile.push_back(std::string(argv[i]));
121 MACH3LOG_WARN(
"Gelman is going to be sad :(. He suggested you should use more than one chain (at least 4). Code works fine for one chain, however, estimator might be biased.");
122 MACH3LOG_WARN(
"Multiple chains are more likely to reveal multimodality and poor adaptation or mixing:");
151 std::vector<int> BurnIn(
Nchains);
152 std::vector<int> nEntries(
Nchains);
153 std::vector<int> nBranches(
Nchains);
154 std::vector<int> step(
Nchains);
166 for (
int m = 0; m <
Nchains; m++)
168 TChain* Chain =
new TChain(
"posteriors");
171 nEntries[m] = int(Chain->GetEntries());
179 BurnIn[m] = nEntries[m]/5;
182 TObjArray* brlis = Chain->GetListOfBranches();
185 nBranches[m] = brlis->GetEntries();
190 Chain->SetBranchStatus(
"*",
false);
194 for (
int i = 0; i < nBranches[m]; i++)
197 TBranch* br =
static_cast<TBranch *
>(brlis->At(i));
202 TString bname = br->GetName();
205 if (bname ==
"step") {
206 Chain->SetBranchStatus(bname,
true);
207 Chain->SetBranchAddress(bname, &step[m]);
210 else if (bname.BeginsWith(
"PCA_") || bname.BeginsWith(
"accProb") || bname.BeginsWith(
"stepTime") )
221 if(bname.BeginsWith(
"LogL"))
230 Chain->SetBranchStatus(bname,
true);
245 for (
int id = 0;
id <
nDraw; ++id)
259 if(nBranches[m] != nBranches[0])
261 MACH3LOG_ERROR(
"Ups, something went wrong, chain {} called {} has {} branches, while 0 called {} has {} branches", m,
MCMCFile[m], nBranches[m],
MCMCFile[0], nBranches[0]);
262 MACH3LOG_ERROR(
"All chains should have the same number of branches");
269 double* branch_values =
new double[
nDraw]();
270 for (
int id = 0;
id <
nDraw; ++id)
272 Chain->SetBranchAddress(
BranchNames[
id].Data(), &branch_values[id]);
276 if(BurnIn[m] >= nEntries[m])
278 MACH3LOG_ERROR(
"You are running on a chain shorter than BurnIn cut");
279 MACH3LOG_ERROR(
"Number of entries {} BurnIn cut {}", nEntries[m], BurnIn[m]);
291 Chain->GetEntry(entry);
295 if (step[m] < BurnIn[m])
309 for (
int j = 0; j <
nDraw; ++j)
312 S2_global[j] += branch_values[j]*branch_values[j];
314 S2_chain[m][j] += branch_values[j]*branch_values[j];
325 delete[] branch_values;
330 MACH3LOG_INFO(
"Finished calculating Toys, it took {:.2f}s to finish", clock.RealTime());
350 for (
int m = 0; m <
Nchains; ++m)
355 for (
int j = 0; j <
nDraw; ++j)
402 for (
int m = 0; m <
Nchains; ++m)
404 for (
int j = 0; j <
nDraw; ++j)
415 for (
int j = 0; j <
nDraw; ++j)
417 for (
int m = 0; m <
Nchains; ++m)
428 for (
int j = 0; j <
nDraw; ++j)
437 for (
int m = 0; m <
Nchains; ++m)
449 for (
int j = 0; j <
nDraw; ++j)
458 for (
int j = 0; j <
nDraw; ++j)
470 for (
int j = 0; j <
nDraw; ++j)
482 MACH3LOG_INFO(
"Finished calculating RHat, it took {:.2f}s to finish", clock.RealTime());
489 #pragma GCC diagnostic ignored "-Wfloat-conversion"
491 std::string NameTemp =
"";
495 for (
int i = 0; i <
Nchains; i++)
499 while (temp.find(
".root") != std::string::npos) {
500 temp = temp.substr(0, temp.find(
".root"));
503 NameTemp = NameTemp + temp +
"_";
507 NameTemp = std::to_string(
Nchains) +
"Chains" +
"_";
509 NameTemp +=
"diag.root";
511 TFile* DiagFile =
new TFile(NameTemp.c_str(),
"recreate");
515 TH1D *StandardDeviationGlobalPlot =
new TH1D(
"StandardDeviationGlobalPlot",
"StandardDeviationGlobalPlot",
nDraw, 0,
nDraw);
516 TH1D *BetweenChainVariancePlot =
new TH1D(
"BetweenChainVariancePlot",
"BetweenChainVariancePlot",
nDraw, 0,
nDraw);
517 TH1D *MarginalPosteriorVariancePlot =
new TH1D(
"MarginalPosteriorVariancePlot",
"MarginalPosteriorVariancePlot",
nDraw, 0,
nDraw);
518 TH1D *RhatPlot =
new TH1D(
"RhatPlot",
"RhatPlot", 200, 0, 2);
519 TH1D *EffectiveSampleSizePlot =
new TH1D(
"EffectiveSampleSizePlot",
"EffectiveSampleSizePlot", 400, 0, 10000);
521 TH1D *RhatLogPlot =
new TH1D(
"RhatLogPlot",
"RhatLogPlot", 200, 0, 2);
524 for(
int j = 0; j <
nDraw; j++)
532 RhatPlot->Fill(
RHat[j]);
534 if(
RHat[j] > 1.1) Criterium++;
538 RhatLogPlot->Fill(
RHat[j]);
542 MACH3LOG_WARN(
"Number of parameters which has R hat greater than 1.1 is {}({:.2f}%)", Criterium, 100*
double(Criterium)/
double(
nDraw));
543 for(
int j = 0; j <
nDraw; j++)
550 StandardDeviationGlobalPlot->Write();
551 BetweenChainVariancePlot->Write();
552 MarginalPosteriorVariancePlot->Write();
554 EffectiveSampleSizePlot->Write();
556 RhatLogPlot->Write();
559 auto TempCanvas = std::make_unique<TCanvas>(
"Canvas",
"Canvas", 1024, 1024);
560 gStyle->SetOptStat(0);
561 TempCanvas->SetGridx();
562 TempCanvas->SetGridy();
565 auto TempLine = std::make_unique<TLine>(0, 0, 0, 0);
566 TempLine->SetLineColor(kBlack);
568 RhatPlot->GetXaxis()->SetTitle(
"R hat");
569 RhatPlot->SetLineColor(kRed);
570 RhatPlot->SetFillColor(kRed);
572 TLegend *Legend =
new TLegend(0.55, 0.6, 0.9, 0.9);
573 Legend->SetTextSize(0.04);
574 Legend->SetFillColor(0);
575 Legend->SetFillStyle(0);
576 Legend->SetLineWidth(0);
577 Legend->SetLineColor(0);
579 Legend->AddEntry(TempLine.get(), Form(
"Number of throws=%.0i, Number of chains=%.1i",
TotToys,
Nchains),
"");
580 Legend->AddEntry(RhatPlot,
"Rhat Gelman 2013",
"l");
583 Legend->Draw(
"same");
584 TempCanvas->Write(
"Rhat");
589 RhatLogPlot->GetXaxis()->SetTitle(
"R hat for LogL");
590 RhatLogPlot->SetLineColor(kRed);
591 RhatLogPlot->SetFillColor(kRed);
593 Legend =
new TLegend(0.55, 0.6, 0.9, 0.9);
594 Legend->SetTextSize(0.04);
595 Legend->SetFillColor(0);
596 Legend->SetFillStyle(0);
597 Legend->SetLineWidth(0);
598 Legend->SetLineColor(0);
600 Legend->AddEntry(TempLine.get(), Form(
"Number of throws=%.0i, Number of chains=%.1i",
TotToys,
Nchains),
"");
601 Legend->AddEntry(RhatLogPlot,
"Rhat Gelman 2013",
"l");
604 Legend->Draw(
"same");
605 TempCanvas->Write(
"RhatLog");
610 EffectiveSampleSizePlot->GetXaxis()->SetTitle(
"S_{eff, BDA2}");
611 EffectiveSampleSizePlot->SetLineColor(kRed);
613 Legend =
new TLegend(0.45, 0.6, 0.9, 0.9);
614 Legend->SetTextSize(0.03);
615 Legend->SetFillColor(0);
616 Legend->SetFillStyle(0);
617 Legend->SetLineWidth(0);
618 Legend->SetLineColor(0);
620 const double Mean1 = EffectiveSampleSizePlot->GetMean();
621 const double RMS1 = EffectiveSampleSizePlot->GetRMS();
623 Legend->AddEntry(TempLine.get(), Form(
"Number of throws=%.0i, Number of chains=%.1i",
TotToys,
Nchains),
"");
624 Legend->AddEntry(EffectiveSampleSizePlot, Form(
"S_{eff, BDA2} #mu = %.2f, #sigma = %.2f",Mean1 ,RMS1),
"l");
626 EffectiveSampleSizePlot->Draw();
627 Legend->Draw(
"same");
628 TempCanvas->Write(
"EffectiveSampleSize");
631 delete StandardDeviationGlobalPlot;
632 delete BetweenChainVariancePlot;
633 delete MarginalPosteriorVariancePlot;
635 delete EffectiveSampleSizePlot;
660 for(
int m = 0; m <
Nchains; m++)
682 std::sort(arr, arr+
size);
685 return (arr[(
size-1)/2] + arr[
size/2])/2.0;
692 if(std::isnan(var) || !std::isfinite(var)) var = cap;
#define _MaCh3_Safe_Include_Start_
KS: Avoiding warning checking for headers.
#define _MaCh3_Safe_Include_End_
KS: Restore warning checking after including external headers.
#define MACH3LOG_CRITICAL
void SetMaCh3LoggerFormat()
Set messaging format of the logger.
int main(int argc, char *argv[])
void CapVariable(double var, double cap)
std::vector< bool > ValidPar
double * BetweenChainVariance
double ** StandardDeviation
double * StandardDeviationGlobal
double * MarginalPosteriorVariance
std::vector< TString > BranchNames
std::vector< std::string > MCMCFile
double CalcMedian(double arr[], int size)
double * EffectiveSampleSize
Custom exception class for MaCh3 errors.
void PrintProgressBar(const Long64_t Done, const Long64_t All)
KS: Simply print progress bar.
void MaCh3Welcome()
KS: Prints welcome message with MaCh3 logo.