MaCh3  2.2.3
Reference Guide
RHat_HighMem.cpp
Go to the documentation of this file.
// C++ standard library includes
#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <memory>
#include <string>
#include <vector>

// MaCh3 includes
#include "Manager/Manager.h"

// ROOT includes
#include "TObjArray.h"
#include "TChain.h"
#include "TFile.h"
#include "TBranch.h"
#include "TCanvas.h"
#include "TLine.h"
#include "TLegend.h"
#include "TString.h"
#include "TH1.h"
#include "TRandom3.h"
#include "TStopwatch.h"
#include "TColor.h"
#include "TStyle.h"
#include "TROOT.h"
22 
31 
32 // *******************
33 int Ntoys;
34 int Nchains;
35 
36 int nDraw;
37 
38 std::vector<TString> BranchNames;
39 std::vector<std::string> MCMCFile;
40 std::vector<bool> ValidPar;
41 
42 double ***Draws;
43 
44 double** Mean;
46 
47 double* MeanGlobal;
49 
52 double* RHat;
54 
55 double ***DrawsFolded;
56 double* MedianArr;
57 
58 double** MeanFolded;
60 
63 
66 double* RHatFolded;
68 // *******************
69 void PrepareChains();
70 void InitialiseArrays();
71 
72 void RunDiagnostic();
73 void CalcRhat();
74 
75 void SaveResults();
76 void DestroyArrays();
77 double CalcMedian(double arr[], int size);
78 
79 void CapVariable(double var, double cap);
80 
81 // *******************
82 int main(int argc, char *argv[]) {
83 // *******************
84 
87 
88  Draws = nullptr;
89  Mean = nullptr;
90  StandardDeviation = nullptr;
91 
92  MeanGlobal = nullptr;
93  StandardDeviationGlobal = nullptr;
94 
95  BetweenChainVariance = nullptr;
96  MarginalPosteriorVariance = nullptr;
97  RHat = nullptr;
98  EffectiveSampleSize = nullptr;
99 
100  DrawsFolded = nullptr;
101  MedianArr = nullptr;
102  MeanFolded = nullptr;
103  StandardDeviationFolded = nullptr;
104 
105  MeanGlobalFolded = nullptr;
107 
108  BetweenChainVarianceFolded = nullptr;
110  RHatFolded = nullptr;
111  EffectiveSampleSizeFolded = nullptr;
112 
113  Nchains = 0;
114 
115  if (argc == 1 || argc == 2)
116  {
117  MACH3LOG_ERROR("Wrong arguments");
118  MACH3LOG_ERROR("./RHat Ntoys MCMCchain_1.root MCMCchain_2.root MCMCchain_3.root ... [how many you like]");
119  throw MaCh3Exception(__FILE__ , __LINE__ );
120  }
121 
122  Ntoys = atoi(argv[1]);
123 
124  //KS Gelman suggests to diagnose on more than one chain
125  for (int i = 2; i < argc; i++)
126  {
127  MCMCFile.push_back(std::string(argv[i]));
128  MACH3LOG_INFO("Adding file: {}", MCMCFile.back());
129  Nchains++;
130  }
131 
132  if(Ntoys < 1)
133  {
134  MACH3LOG_ERROR("You specified {} specify larger greater than 0", Ntoys);
135  throw MaCh3Exception(__FILE__ , __LINE__ );
136  }
137 
138  if(Nchains == 1)
139  {
140  MACH3LOG_WARN("Gelman is going to be sad :(. He suggested you should use more than one chain (at least 4). Code works fine for one chain, however, estimator might be biased.");
141  MACH3LOG_WARN("Multiple chains are more likely to reveal multimodality and poor adaptation or mixing:");
142  }
143  MACH3LOG_INFO("Diagnosing {} chains, with {} toys", Nchains, Ntoys);
144 
145  PrepareChains();
146 
148 
149  //KS: Main function
150  RunDiagnostic();
151 
152  SaveResults();
153 
154  DestroyArrays();
155 
156  return 0;
157 }
158 
159 // *******************
160 // Load chain and prepare toys
162 // *******************
163  auto rnd = std::make_unique<TRandom3>(0);
164 
165  MACH3LOG_INFO("Generating {}", Ntoys);
166 
167  TStopwatch clock;
168  clock.Start();
169 
170  std::vector<int> BurnIn(Nchains);
171  std::vector<int> nEntries(Nchains);
172  std::vector<int> nBranches(Nchains);
173  std::vector<int> step(Nchains);
174 
175  Draws = new double**[Nchains]();
176  DrawsFolded = new double**[Nchains]();
177 
178  // KS: This can reduce time necessary for caching even by half
179  #ifdef MULTITHREAD
180  //ROOT::EnableImplicitMT();
181  #endif
182 
183  // Open the Chain
184  //It is tempting to multithread here but unfortunately, ROOT files are not thread safe :(
185  for (int m = 0; m < Nchains; m++)
186  {
187  TChain* Chain = new TChain("posteriors");
188  Chain->Add(MCMCFile[m].c_str());
189  MACH3LOG_INFO("On file: {}", MCMCFile[m].c_str());
190  nEntries[m] = int(Chain->GetEntries());
191 
192  // Set the step cut to be 20%
193  BurnIn[m] = nEntries[m]/5;
194 
195  // Get the list of branches
196  TObjArray* brlis = Chain->GetListOfBranches();
197 
198  // Get the number of branches
199  nBranches[m] = brlis->GetEntries();
200 
201  if(m == 0) BranchNames.reserve(nBranches[m]);
202 
203  // Set all the branches to off
204  Chain->SetBranchStatus("*", false);
205 
206  // Loop over the number of branches
207  // Find the name and how many of each systematic we have
208  for (int i = 0; i < nBranches[m]; i++)
209  {
210  // Get the TBranch and its name
211  TBranch* br = static_cast<TBranch *>(brlis->At(i));
212  if(!br){
213  MACH3LOG_ERROR("Invalid branch at position {}", i);
214  throw MaCh3Exception(__FILE__,__LINE__);
215  }
216  TString bname = br->GetName();
217 
218  // Read in the step
219  if (bname == "step") {
220  Chain->SetBranchStatus(bname, true);
221  Chain->SetBranchAddress(bname, &step[m]);
222  }
223  //Count all branches
224  else if (bname.BeginsWith("PCA_") || bname.BeginsWith("accProb") || bname.BeginsWith("stepTime") )
225  {
226  continue;
227  }
228  else
229  {
230  //KS: Save branch name only for one chain, we assume all chains have the same branches, otherwise this doesn't make sense either way
231  if(m == 0)
232  {
233  BranchNames.push_back(bname);
234  //KS: We calculate R Hat also for LogL, just in case, however we plot them separately
235  if(bname.BeginsWith("LogL"))
236  {
237  ValidPar.push_back(false);
238  }
239  else
240  {
241  ValidPar.push_back(true);
242  }
243  }
244  Chain->SetBranchStatus(bname, true);
245  MACH3LOG_DEBUG("{}", bname);
246  }
247  }
248 
249  if(m == 0) nDraw = int(BranchNames.size());
250 
251  //TN: Qualitatively faster sanity check, with the very same outcome (all chains have the same #branches)
252  if(m > 0)
253  {
254  if(nBranches[m] != nBranches[0])
255  {
256  MACH3LOG_ERROR("Ups, something went wrong, chain {} called {} has {} branches, while 0 called {} has {} branches", m, MCMCFile[m], nBranches[m], MCMCFile[0], nBranches[0]);
257  MACH3LOG_ERROR("All chains should have the same number of branches");
258  throw MaCh3Exception(__FILE__ , __LINE__ );
259  }
260  }
261 
262  //TN: move the Draws here, so we need to iterate over every chain only once
263  Draws[m] = new double*[Ntoys]();
264  DrawsFolded[m] = new double*[Ntoys]();
265  for(int i = 0; i < Ntoys; i++)
266  {
267  Draws[m][i] = new double[nDraw]();
268  DrawsFolded[m][i] = new double[nDraw]();
269  for(int j = 0; j < nDraw; j++)
270  {
271  Draws[m][i][j] = 0.;
272  DrawsFolded[m][i][j] = 0.;
273  }
274  }
275 
276  // MJR: array to hold branch values; SetBranchAddress in every step is very
277  // expensive, so doing it once only here saves time
278  double* branch_values = new double[nDraw]();
279  for (int j = 0; j < nDraw; ++j)
280  {
281  Chain->SetBranchAddress(BranchNames[j].Data(), &branch_values[j]);
282  }
283 
284  //TN: move looping over toys here, so we don't need to loop over chains more than once
285  if(BurnIn[m] >= nEntries[m])
286  {
287  MACH3LOG_ERROR("You are running on a chain shorter than BurnIn cut");
288  MACH3LOG_ERROR("Number of entries {} BurnIn cut {}", nEntries[m], BurnIn[m]);
289  MACH3LOG_ERROR("You will run into the infinite loop");
290  MACH3LOG_ERROR("You can make a new chain or modify BurnIn cut");
291  throw MaCh3Exception(__FILE__ , __LINE__ );
292  }
293 
294  for (int i = 0; i < Ntoys; i++)
295  {
296  // Get a random entry after burn in
297  int entry = int(nEntries[m]*rnd->Rndm());
298 
299  Chain->GetEntry(entry);
300 
301  // If we have combined chains by hadd need to check the step in the chain
302  // Note, entry is not necessarily the same as the step due to merged ROOT files, so can't choose an entry in the range BurnIn - nEntries :(
303  if (step[m] < BurnIn[m])
304  {
305  i--;
306  continue;
307  }
308 
309  // Output some info for the user
310  if (Ntoys > 10 && i % (Ntoys/10) == 0) {
311  MaCh3Utils::PrintProgressBar(i+m*Ntoys, static_cast<Long64_t>(Ntoys)*Nchains);
312  MACH3LOG_DEBUG("Getting random entry {}", entry);
313  }
314 
315  // Set the branch addresses for params
316  for (int j = 0; j < nDraw; ++j)
317  {
318  Draws[m][i][j] = branch_values[j];
319  }
320 
321  }//end loop over toys
322 
323  //TN: There, we now don't need to keep the chain in memory anymore
324  delete Chain;
325  delete[] branch_values;
326  }
327 
328  //KS: Now prepare folded draws, quoting Gelman
329  //"We propose to report the maximum of rank normalized split-Rb and rank normalized folded-split-Rb for each parameter"
330  MedianArr = new double[nDraw]();
331  #ifdef MULTITHREAD
332  #pragma omp parallel for
333  #endif
334  for(int j = 0; j < nDraw; j++)
335  {
336  MedianArr[j] = 0.;
337  std::vector<double> TempDraws(static_cast<size_t>(Ntoys) * Nchains);
338  for(int m = 0; m < Nchains; m++)
339  {
340  for(int i = 0; i < Ntoys; i++)
341  {
342  const int im = i+m;
343  TempDraws[im] = Draws[m][i][j];
344  }
345  }
346  MedianArr[j] = CalcMedian(TempDraws.data(), Ntoys*Nchains);
347  }
348 
349  #ifdef MULTITHREAD
350  #pragma omp parallel for collapse(3)
351  #endif
352  for(int m = 0; m < Nchains; m++)
353  {
354  for(int i = 0; i < Ntoys; i++)
355  {
356  for(int j = 0; j < nDraw; j++)
357  {
358  DrawsFolded[m][i][j] = std::fabs(Draws[m][i][j] - MedianArr[j]);
359  }
360  }
361  }
362  clock.Stop();
363  MACH3LOG_INFO("Finished calculating Toys, it took {:.2f}s to finish", clock.RealTime());
364 }
365 
366 // *******************
367 // Create all arrays we are going to use later
369 // *******************
370 
371  MACH3LOG_INFO("Initialising arrays");
372  Mean = new double*[Nchains]();
373  StandardDeviation = new double*[Nchains]();
374 
375  MeanGlobal = new double[nDraw]();
376  StandardDeviationGlobal = new double[nDraw]();
377  BetweenChainVariance = new double[nDraw]();
378 
379  MarginalPosteriorVariance = new double[nDraw]();
380  RHat = new double[nDraw]();
381  EffectiveSampleSize = new double[nDraw]();
382 
383  MeanFolded = new double*[Nchains]();
384  StandardDeviationFolded = new double*[Nchains]();
385 
386  MeanGlobalFolded = new double[nDraw]();
387  StandardDeviationGlobalFolded = new double[nDraw]();
388  BetweenChainVarianceFolded = new double[nDraw]();
389 
390  MarginalPosteriorVarianceFolded = new double[nDraw]();
391  RHatFolded = new double[nDraw]();
392  EffectiveSampleSizeFolded = new double[nDraw]();
393 
394  for (int m = 0; m < Nchains; ++m)
395  {
396  Mean[m] = new double[nDraw]();
397  StandardDeviation[m] = new double[nDraw]();
398 
399  MeanFolded[m] = new double[nDraw]();
400  StandardDeviationFolded[m] = new double[nDraw]();
401  for (int j = 0; j < nDraw; ++j)
402  {
403  Mean[m][j] = 0.;
404  StandardDeviation[m][j] = 0.;
405 
406  MeanFolded[m][j] = 0.;
407  StandardDeviationFolded[m][j] = 0.;
408  if(m == 0)
409  {
410  MeanGlobal[j] = 0.;
411  StandardDeviationGlobal[j] = 0.;
412  BetweenChainVariance[j] = 0.;
414  RHat[j] = 0.;
415  EffectiveSampleSize[j] = 0.;
416 
417  MeanGlobalFolded[j] = 0.;
421  RHatFolded[j] = 0.;
423  }
424  }
425  }
426 }
427 
428 // *******************
430 // *******************
431  CalcRhat();
432  //In case in future we expand this
433 }
434 
435 // *******************
436 //KS: Based on Gelman et. al. arXiv:1903.08008v5
437 // Probably most of it could be moved cleverly to MCMC Processor, keep it separate for now
438 void CalcRhat() {
439 // *******************
440 
441  TStopwatch clock;
442  clock.Start();
443 
444  //KS: Start parallel region
445  // If we would like to do this for thousands of chains we might consider using GPU for this
446  #ifdef MULTITHREAD
447  #pragma omp parallel
448  {
449  #endif
450 
451  #ifdef MULTITHREAD
452  #pragma omp for collapse(2)
453  #endif
454  //KS: loop over chains and draws are independent so might as well collapse for sweet cache hits
455  //Calculate the mean for each parameter within each considered chain
456  for (int m = 0; m < Nchains; ++m)
457  {
458  for (int j = 0; j < nDraw; ++j)
459  {
460  for(int i = 0; i < Ntoys; i++)
461  {
462  Mean[m][j] += Draws[m][i][j];
463  MeanFolded[m][j] += DrawsFolded[m][i][j];
464  }
465  Mean[m][j] = Mean[m][j]/Ntoys;
466  MeanFolded[m][j] = MeanFolded[m][j]/Ntoys;
467  }
468  }
469 
470  #ifdef MULTITHREAD
471  #pragma omp for
472  #endif
473  //Calculate the mean for each parameter global means we include information from several chains
474  for (int j = 0; j < nDraw; ++j)
475  {
476  for (int m = 0; m < Nchains; ++m)
477  {
478  MeanGlobal[j] += Mean[m][j];
479  MeanGlobalFolded[j] += MeanFolded[m][j];
480  }
481  MeanGlobal[j] = MeanGlobal[j]/Nchains;
483  }
484 
485 
486  #ifdef MULTITHREAD
487  #pragma omp for collapse(2)
488  #endif
489  //Calculate the standard deviation for each parameter within each considered chain
490  for (int m = 0; m < Nchains; ++m)
491  {
492  for (int j = 0; j < nDraw; ++j)
493  {
494  for(int i = 0; i < Ntoys; i++)
495  {
496  StandardDeviation[m][j] += (Draws[m][i][j] - Mean[m][j])*(Draws[m][i][j] - Mean[m][j]);
497  StandardDeviationFolded[m][j] += (DrawsFolded[m][i][j] - MeanFolded[m][j])*(DrawsFolded[m][i][j] - MeanFolded[m][j]);
498  }
499  StandardDeviation[m][j] = StandardDeviation[m][j]/(Ntoys-1);
501  }
502  }
503 
504  #ifdef MULTITHREAD
505  #pragma omp for
506  #endif
507  //Calculate the standard deviation for each parameter combining information from all chains
508  for (int j = 0; j < nDraw; ++j)
509  {
510  for (int m = 0; m < Nchains; ++m)
511  {
514  }
517  }
518 
519  #ifdef MULTITHREAD
520  #pragma omp for
521  #endif
522  for (int j = 0; j < nDraw; ++j)
523  {
524  //KS: This term only makes sense if we have at least 2 chains
525  if(Nchains == 1)
526  {
527  BetweenChainVariance[j] = 0.;
529  }
530  else
531  {
532  for (int m = 0; m < Nchains; ++m)
533  {
534  BetweenChainVariance[j] += ( Mean[m][j] - MeanGlobal[j])*( Mean[m][j] - MeanGlobal[j]);
536  }
539  }
540  }
541 
542  #ifdef MULTITHREAD
543  #pragma omp for
544  #endif
545  for (int j = 0; j < nDraw; ++j)
546  {
549  }
550 
551  #ifdef MULTITHREAD
552  #pragma omp for
553  #endif
554  //Finally calculate our estimator
555  for (int j = 0; j < nDraw; ++j)
556  {
559 
560  //KS: For flat params values can be crazy so cap at 0
561  CapVariable(RHat[j], 0);
562  CapVariable(RHatFolded[j], 0);
563  }
564 
565  #ifdef MULTITHREAD
566  #pragma omp for
567  #endif
568  //KS: Additionally calculates effective step size which is an estimate of the sample size required to achieve the same level of precision if that sample was a simple random sample.
569  for (int j = 0; j < nDraw; ++j)
570  {
573 
574  //KS: For flat params values can be crazy so cap at 0
577  }
578  #ifdef MULTITHREAD
579  } //End parallel region
580  #endif
581 
582  clock.Stop();
583  MACH3LOG_INFO("Finished calculating RHat, it took {:.2f}s to finish", clock.RealTime());
584 }
585 
586 
587 // *******************
588 void SaveResults() {
589 // *******************
590  #pragma GCC diagnostic ignored "-Wfloat-conversion"
591 
592  std::string NameTemp = "";
593  //KS: If we run over many many chains there is danger that name will be so absurdly long we run over system limit and job will be killed :(
594  if(Nchains < 5)
595  {
596  for (int i = 0; i < Nchains; i++)
597  {
598  std::string temp = MCMCFile[i];
599 
600  while (temp.find(".root") != std::string::npos) {
601  temp = temp.substr(0, temp.find(".root"));
602  }
603 
604  NameTemp = NameTemp + temp + "_";
605  }
606  }
607  else {
608  NameTemp = std::to_string(Nchains) + "Chains" + "_";
609  }
610  NameTemp += "diag.root";
611 
612  TFile* DiagFile = new TFile(NameTemp.c_str(), "recreate");
613 
614  DiagFile->cd();
615 
616  TH1D *StandardDeviationGlobalPlot = new TH1D("StandardDeviationGlobalPlot", "StandardDeviationGlobalPlot", nDraw, 0, nDraw);
617  TH1D *BetweenChainVariancePlot = new TH1D("BetweenChainVariancePlot", "BetweenChainVariancePlot", nDraw, 0, nDraw);
618  TH1D *MarginalPosteriorVariancePlot = new TH1D("MarginalPosteriorVariancePlot", "MarginalPosteriorVariancePlot", nDraw, 0, nDraw);
619  TH1D *RhatPlot = new TH1D("RhatPlot", "RhatPlot", 200, 0, 2);
620  TH1D *EffectiveSampleSizePlot = new TH1D("EffectiveSampleSizePlot", "EffectiveSampleSizePlot", 400, 0, 10000);
621 
622  TH1D *StandardDeviationGlobalFoldedPlot = new TH1D("StandardDeviationGlobalFoldedPlot", "StandardDeviationGlobalFoldedPlot", nDraw, 0, nDraw);
623  TH1D *BetweenChainVarianceFoldedPlot = new TH1D("BetweenChainVarianceFoldedPlot", "BetweenChainVarianceFoldedPlot", nDraw, 0, nDraw);
624  TH1D *MarginalPosteriorVarianceFoldedPlot = new TH1D("MarginalPosteriorVarianceFoldedPlot", "MarginalPosteriorVarianceFoldedPlot", nDraw, 0, nDraw);
625  TH1D *RhatFoldedPlot = new TH1D("RhatFoldedPlot", "RhatFoldedPlot", 200, 0, 2);
626  TH1D *EffectiveSampleSizeFoldedPlot = new TH1D("EffectiveSampleSizeFoldedPlot", "EffectiveSampleSizeFoldedPlot", 400, 0, 10000);
627 
628  TH1D *RhatLogPlot = new TH1D("RhatLogPlot", "RhatLogPlot", 200, 0, 2);
629  TH1D *RhatFoldedLogPlot = new TH1D("RhatFoldedLogPlot", "RhatFoldedLogPlot", 200, 0, 2);
630 
631  int Criterium = 0;
632  int CiteriumFolded = 0;
633  for(int j = 0; j < nDraw; j++)
634  {
635  //KS: Fill only valid parameters
636  if(ValidPar[j])
637  {
638  StandardDeviationGlobalPlot->Fill(j,StandardDeviationGlobal[j]);
639  BetweenChainVariancePlot->Fill(j,BetweenChainVariance[j]);
640  MarginalPosteriorVariancePlot->Fill(j,MarginalPosteriorVariance[j]);
641  RhatPlot->Fill(RHat[j]);
642  EffectiveSampleSizePlot->Fill(EffectiveSampleSize[j]);
643  if(RHat[j] > 1.1) Criterium++;
644 
645 
646  StandardDeviationGlobalFoldedPlot->Fill(j,StandardDeviationGlobalFolded[j]);
647  BetweenChainVarianceFoldedPlot->Fill(j,BetweenChainVarianceFolded[j]);
648  MarginalPosteriorVarianceFoldedPlot->Fill(j,MarginalPosteriorVarianceFolded[j]);
649  RhatFoldedPlot->Fill(RHatFolded[j]);
650  EffectiveSampleSizeFoldedPlot->Fill(EffectiveSampleSizeFolded[j]);
651  if(RHatFolded[j] > 1.1) CiteriumFolded++;
652  }
653  else
654  {
655  RhatLogPlot->Fill(RHat[j]);
656  RhatFoldedLogPlot->Fill(RHatFolded[j]);
657  }
658  }
659  //KS: We set criterium of 1.1 based on Gelman et al. (2003) Bayesian Data Analysis
660  MACH3LOG_WARN("Number of parameters which has R hat greater than 1.1 is {}({:.2f}%) while for R hat folded {}({:.2f}%)", Criterium, 100*double(Criterium)/double(nDraw), CiteriumFolded, 100*double(CiteriumFolded)/double(nDraw));
661  for(int j = 0; j < nDraw; j++)
662  {
663  if( (RHat[j] > 1.1 || RHatFolded[j] > 1.1) && ValidPar[j])
664  {
665  MACH3LOG_CRITICAL("Parameter {} has R hat higher than 1.1", BranchNames[j]);
666  }
667  }
668  StandardDeviationGlobalPlot->Write();
669  BetweenChainVariancePlot->Write();
670  MarginalPosteriorVariancePlot->Write();
671  RhatPlot->Write();
672  EffectiveSampleSizePlot->Write();
673 
674  StandardDeviationGlobalFoldedPlot->Write();
675  BetweenChainVarianceFoldedPlot->Write();
676  MarginalPosteriorVarianceFoldedPlot->Write();
677  RhatFoldedPlot->Write();
678  EffectiveSampleSizeFoldedPlot->Write();
679 
680  RhatLogPlot->Write();
681  RhatFoldedLogPlot->Write();
682 
683  //KS: Now we make fancy canvases, consider some function to have less copy pasting
684  auto TempCanvas = std::make_unique<TCanvas>("Canvas", "Canvas", 1024, 1024);
685  gStyle->SetOptStat(0);
686  TempCanvas->SetGridx();
687  TempCanvas->SetGridy();
688 
689  // Random line to write useful information to TLegend
690  auto TempLine = std::make_unique<TLine>(0, 0, 0, 0);
691  TempLine->SetLineColor(kBlack);
692 
693  RhatPlot->GetXaxis()->SetTitle("R hat");
694  RhatPlot->SetLineColor(kRed);
695  RhatPlot->SetFillColor(kRed);
696  RhatFoldedPlot->SetLineColor(kBlue);
697  RhatFoldedPlot->SetFillColor(kBlue);
698 
699  TLegend *Legend = new TLegend(0.55, 0.6, 0.9, 0.9);
700  Legend->SetTextSize(0.04);
701  Legend->SetFillColor(0);
702  Legend->SetFillStyle(0);
703  Legend->SetLineWidth(0);
704  Legend->SetLineColor(0);
705 
706  Legend->AddEntry(TempLine.get(), Form("Number of throws=%.0i, Number of chains=%.1i", Ntoys, Nchains), "");
707  Legend->AddEntry(RhatPlot, "Rhat Gelman 2013", "l");
708  Legend->AddEntry(RhatFoldedPlot, "Rhat-Folded Gelman 2021", "l");
709 
710  RhatPlot->Draw();
711  RhatFoldedPlot->Draw("same");
712  Legend->Draw("same");
713  TempCanvas->Write("Rhat");
714  delete Legend;
715  Legend = nullptr;
716 
717  //Now R hat for log L
718  RhatLogPlot->GetXaxis()->SetTitle("R hat for LogL");
719  RhatLogPlot->SetLineColor(kRed);
720  RhatLogPlot->SetFillColor(kRed);
721  RhatFoldedLogPlot->SetLineColor(kBlue);
722  RhatFoldedLogPlot->SetFillColor(kBlue);
723 
724  Legend = new TLegend(0.55, 0.6, 0.9, 0.9);
725  Legend->SetTextSize(0.04);
726  Legend->SetFillColor(0);
727  Legend->SetFillStyle(0);
728  Legend->SetLineWidth(0);
729  Legend->SetLineColor(0);
730 
731  Legend->AddEntry(TempLine.get(), Form("Number of throws=%.0i, Number of chains=%.1i", Ntoys, Nchains), "");
732  Legend->AddEntry(RhatLogPlot, "Rhat Gelman 2013", "l");
733  Legend->AddEntry(RhatFoldedLogPlot, "Rhat-Folded Gelman 2021", "l");
734 
735  RhatLogPlot->Draw();
736  RhatFoldedLogPlot->Draw("same");
737  Legend->Draw("same");
738  TempCanvas->Write("RhatLog");
739  delete Legend;
740  Legend = nullptr;
741 
742  //Now canvas for effective sample size
743  EffectiveSampleSizePlot->GetXaxis()->SetTitle("S_{eff, BDA2}");
744  EffectiveSampleSizePlot->SetLineColor(kRed);
745  EffectiveSampleSizeFoldedPlot->SetLineColor(kBlue);
746 
747  Legend = new TLegend(0.45, 0.6, 0.9, 0.9);
748  Legend->SetTextSize(0.03);
749  Legend->SetFillColor(0);
750  Legend->SetFillStyle(0);
751  Legend->SetLineWidth(0);
752  Legend->SetLineColor(0);
753 
754  const double Mean1 = EffectiveSampleSizePlot->GetMean();
755  const double RMS1 = EffectiveSampleSizePlot->GetRMS();
756  const double Mean2 = EffectiveSampleSizeFoldedPlot->GetMean();
757  const double RMS2 = EffectiveSampleSizeFoldedPlot->GetRMS();
758 
759  Legend->AddEntry(TempLine.get(), Form("Number of throws=%.0i, Number of chains=%.1i", Ntoys, Nchains), "");
760  Legend->AddEntry(EffectiveSampleSizePlot, Form("S_{eff, BDA2} #mu = %.2f, #sigma = %.2f",Mean1 ,RMS1), "l");
761  Legend->AddEntry(EffectiveSampleSizeFoldedPlot, Form("S_{eff, BDA2} Folded, #mu = %.2f, #sigma = %.2f",Mean2 ,RMS2), "l");
762 
763  EffectiveSampleSizePlot->Draw();
764  EffectiveSampleSizeFoldedPlot->Draw("same");
765  Legend->Draw("same");
766  TempCanvas->Write("EffectiveSampleSize");
767 
768  //Fancy memory cleaning
769  delete StandardDeviationGlobalPlot;
770  delete BetweenChainVariancePlot;
771  delete MarginalPosteriorVariancePlot;
772  delete RhatPlot;
773  delete EffectiveSampleSizePlot;
774 
775  delete StandardDeviationGlobalFoldedPlot;
776  delete BetweenChainVarianceFoldedPlot;
777  delete MarginalPosteriorVarianceFoldedPlot;
778  delete RhatFoldedPlot;
779  delete EffectiveSampleSizeFoldedPlot;
780 
781  delete Legend;
782 
783  delete RhatLogPlot;
784  delete RhatFoldedLogPlot;
785 
786  DiagFile->Close();
787  delete DiagFile;
788 
789  MACH3LOG_INFO("Finished and wrote results to {}", NameTemp);
790 }
791 
792 // *******************
793 //KS: Pseudo destructor
795 // *******************
796 
797  MACH3LOG_INFO("Killing all arrays");
798  delete[] MeanGlobal;
799  delete[] StandardDeviationGlobal;
800  delete[] BetweenChainVariance;
801  delete[] MarginalPosteriorVariance;
802  delete[] RHat;
803  delete[] EffectiveSampleSize;
804 
805  delete[] MeanGlobalFolded;
809  delete[] RHatFolded;
810  delete[] EffectiveSampleSizeFolded;
811 
812  for(int m = 0; m < Nchains; m++)
813  {
814  for(int i = 0; i < Ntoys; i++)
815  {
816  delete[] Draws[m][i];
817  delete[] DrawsFolded[m][i];
818  }
819  delete[] Draws[m];
820  delete[] Mean[m];
821  delete[] StandardDeviation[m];
822 
823  delete[] DrawsFolded[m];
824  delete[] MeanFolded[m];
825  delete[] StandardDeviationFolded[m];
826  }
827  delete[] Draws;
828  delete[] Mean;
829  delete[] StandardDeviation;
830 
831  delete[] DrawsFolded;
832  delete[] MedianArr;
833  delete[] MeanFolded;
834  delete[] StandardDeviationFolded;
835 }
836 
837 // *******************
// Calculate the median of the first `size` elements of `arr`.
// The elements are sorted in place, so the caller's array is reordered.
// For an even number of elements the mean of the two central ones is returned.
// Returns 0 when there is nothing to look at (null array or size <= 0); without
// this check such input would be read out of bounds.
double CalcMedian(double arr[], const int size) {
// *******************
  if (arr == nullptr || size <= 0) return 0.0;

  std::sort(arr, arr+size);
  if (size % 2 != 0)
    return arr[size/2];
  return (arr[(size-1)/2] + arr[size/2])/2.0;
}
846 
847 // *******************
// Replace a value that is NaN or infinite by `cap`.
// NOTE(review): `var` is received by value, so the assignment below only changes the local
// copy and the variable of the caller stays as it was - confirm whether a reference was intended.
void CapVariable(double var, const double cap) {
// *******************
  // NaN is not finite either, so this single test covers both cases
  const bool Usable = std::isfinite(var);
  if(!Usable)
  {
    var = cap;
  }
}
#define _MaCh3_Safe_Include_Start_
KS: Avoiding warning checking for headers.
Definition: Core.h:109
#define _MaCh3_Safe_Include_End_
KS: Restore warning checking after including external headers.
Definition: Core.h:120
int size
#define MACH3LOG_CRITICAL
Definition: MaCh3Logger.h:28
#define MACH3LOG_DEBUG
Definition: MaCh3Logger.h:24
#define MACH3LOG_ERROR
Definition: MaCh3Logger.h:27
#define MACH3LOG_INFO
Definition: MaCh3Logger.h:25
void SetMaCh3LoggerFormat()
Set messaging format of the logger.
Definition: MaCh3Logger.h:51
#define MACH3LOG_WARN
Definition: MaCh3Logger.h:26
int main(int argc, char *argv[])
void SaveResults()
int Nchains
double * StandardDeviationGlobalFolded
void CapVariable(double var, double cap)
double * EffectiveSampleSizeFolded
double * BetweenChainVarianceFolded
void InitialiseArrays()
double * MeanGlobal
double ** MeanFolded
void RunDiagnostic()
std::vector< bool > ValidPar
double * BetweenChainVariance
double * MeanGlobalFolded
double ** StandardDeviation
double * StandardDeviationGlobal
double * MarginalPosteriorVariance
double * RHatFolded
double * MedianArr
double *** DrawsFolded
double ** StandardDeviationFolded
std::vector< TString > BranchNames
double *** Draws
std::vector< std::string > MCMCFile
double * RHat
double ** Mean
void DestroyArrays()
void CalcRhat()
double CalcMedian(double arr[], int size)
int nDraw
double * EffectiveSampleSize
void PrepareChains()
double * MarginalPosteriorVarianceFolded
int Ntoys
Custom exception class for MaCh3 errors.
void PrintProgressBar(const Long64_t Done, const Long64_t All)
KS: Simply print progress bar.
Definition: Monitor.cpp:213
void MaCh3Welcome()
KS: Prints welcome message with MaCh3 logo.
Definition: Monitor.cpp:12