MaCh3  2.4.2
Reference Guide
Functions | Variables
RHat_HighMem.cpp File Reference
#include <algorithm>
#include <cmath>
#include <limits>

#include "Manager/Manager.h"
#include "Samples/SampleStructs.h"
#include "Samples/HistogramUtils.h"
#include "TObjArray.h"
#include "TChain.h"
#include "TFile.h"
#include "TBranch.h"
#include "TCanvas.h"
#include "TLine.h"
#include "TLegend.h"
#include "TString.h"
#include "TH1.h"
#include "TRandom3.h"
#include "TStopwatch.h"
#include "TColor.h"
#include "TStyle.h"
#include "TROOT.h"
Include dependency graph for RHat_HighMem.cpp:

Go to the source code of this file.

Functions

void PrepareChains ()
 
void InitialiseArrays ()
 
void RunDiagnostic ()
 
void CalcRhat ()
 
void SaveResults ()
 
void DestroyArrays ()
 
double CalcMedian (double arr[], int size)
 
void CapVariable (double var, double cap)
 
int main (int argc, char *argv[])
 

Variables

int Ntoys
 
int Nchains
 
int nDraw
 
std::vector< TString > BranchNames
 
std::vector< std::string > MCMCFile
 
std::vector< bool > ValidPar
 
double *** Draws
 
double ** Mean
 
double ** StandardDeviation
 
double * MeanGlobal
 
double * StandardDeviationGlobal
 
double * BetweenChainVariance
 
double * MarginalPosteriorVariance
 
double * RHat
 
double * EffectiveSampleSize
 
double *** DrawsFolded
 
double * MedianArr
 
double ** MeanFolded
 
double ** StandardDeviationFolded
 
double * MeanGlobalFolded
 
double * StandardDeviationGlobalFolded
 
double * BetweenChainVarianceFolded
 
double * MarginalPosteriorVarianceFolded
 
double * RHatFolded
 
double * EffectiveSampleSizeFolded
 

Function Documentation

◆ CalcMedian()

double CalcMedian ( double  arr[],
int  size 
)

Definition at line 846 of file RHat_HighMem.cpp.

// *******************
/// @brief Median of the first @p size elements of @p arr.
/// @param arr  Values to inspect; the first @p size elements are sorted in
///             place as a side effect.
/// @param size Number of elements to consider.
/// @return The middle value for odd @p size, the mean of the two middle values
///         for even @p size, or NaN when @p size is not positive.
double CalcMedian(double arr[], int size) {
// *******************
  // Without this guard an empty range would still read arr[0]
  if (size <= 0) return std::numeric_limits<double>::quiet_NaN();

  std::sort(arr, arr + size);
  const int half = size / 2;
  if (size % 2 != 0) return arr[half];
  return (arr[half - 1] + arr[half]) / 2.0;
}

◆ CalcRhat()

void CalcRhat ( )

Definition at line 442 of file RHat_HighMem.cpp.

// CalcRhat: convergence diagnostics from Draws / DrawsFolded, which are
// indexed [chain m][toy i][parameter j].
//  - Mean / MeanFolded: per-chain mean of each parameter over the Ntoys draws.
//  - MeanGlobal / MeanGlobalFolded: accumulated over chains from the per-chain
//    means; MeanGlobal is divided by Nchains in the visible code.
//  - StandardDeviation / StandardDeviationFolded: per-chain sums of squared
//    deviations from the chain mean. StandardDeviation is divided by
//    (Ntoys-1), and no square root is applied in the lines shown here, so the
//    stored value is a sample variance.
//  - BetweenChainVariance: forced to 0 for a single chain; otherwise
//    accumulated from (Mean - MeanGlobal)^2 over chains.
//  - The later loops handle RHat / RHatFolded and EffectiveSampleSize(Folded).
// All accumulations use +=, so the arrays are assumed to be zero on entry.
// InitialiseArrays value-initialises them (TODO confirm it runs first).
// With MULTITHREAD defined, all loops share one OpenMP parallel region.
// NOTE(review): this rendered listing skips source lines 486, 504, 516-517,
// 519-520, 532, 539, 541-542, 551-552, 561-562, 575-576 and 579-580, so
// several loop bodies are only partially visible here.
442  {
443 // *******************
444 
445  TStopwatch clock;
446  clock.Start();
447 
448  //KS: Start parallel region
449  // If we would like to do this for thousands of chains we might consider using GPU for this
450  #ifdef MULTITHREAD
451  #pragma omp parallel
452  {
453  #endif
454 
455  #ifdef MULTITHREAD
456  #pragma omp for collapse(2)
457  #endif
458  //KS: loop over chains and draws are independent so might as well collapse for sweet cache hits
459  //Calculate the mean for each parameter within each considered chain
460  for (int m = 0; m < Nchains; ++m)
461  {
462  for (int j = 0; j < nDraw; ++j)
463  {
464  for(int i = 0; i < Ntoys; i++)
465  {
466  Mean[m][j] += Draws[m][i][j];
467  MeanFolded[m][j] += DrawsFolded[m][i][j];
468  }
469  Mean[m][j] = Mean[m][j]/Ntoys;
470  MeanFolded[m][j] = MeanFolded[m][j]/Ntoys;
471  }
472  }
473 
474  #ifdef MULTITHREAD
475  #pragma omp for
476  #endif
477  //Calculate the mean for each parameter global means we include information from several chains
478  for (int j = 0; j < nDraw; ++j)
479  {
480  for (int m = 0; m < Nchains; ++m)
481  {
482  MeanGlobal[j] += Mean[m][j];
483  MeanGlobalFolded[j] += MeanFolded[m][j];
484  }
485  MeanGlobal[j] = MeanGlobal[j]/Nchains;
487  }
488 
489 
490  #ifdef MULTITHREAD
491  #pragma omp for collapse(2)
492  #endif
493  //Calculate the standard deviation for each parameter within each considered chain
494  for (int m = 0; m < Nchains; ++m)
495  {
496  for (int j = 0; j < nDraw; ++j)
497  {
498  for(int i = 0; i < Ntoys; i++)
499  {
500  StandardDeviation[m][j] += (Draws[m][i][j] - Mean[m][j])*(Draws[m][i][j] - Mean[m][j]);
501  StandardDeviationFolded[m][j] += (DrawsFolded[m][i][j] - MeanFolded[m][j])*(DrawsFolded[m][i][j] - MeanFolded[m][j]);
502  }
503  StandardDeviation[m][j] = StandardDeviation[m][j]/(Ntoys-1);
// NOTE(review): Ntoys == 1 passes main()'s "Ntoys < 1" check but makes
// (Ntoys-1) zero here, producing inf/NaN.
505  }
506  }
507 
508  #ifdef MULTITHREAD
509  #pragma omp for
510  #endif
511  //Calculate the standard deviation for each parameter combining information from all chains
512  for (int j = 0; j < nDraw; ++j)
513  {
514  for (int m = 0; m < Nchains; ++m)
515  {
518  }
521  }
522 
523  #ifdef MULTITHREAD
524  #pragma omp for
525  #endif
526  for (int j = 0; j < nDraw; ++j)
527  {
528  //KS: This term only makes sense if we have at least 2 chains
529  if(Nchains == 1)
530  {
531  BetweenChainVariance[j] = 0.;
533  }
534  else
535  {
536  for (int m = 0; m < Nchains; ++m)
537  {
538  BetweenChainVariance[j] += ( Mean[m][j] - MeanGlobal[j])*( Mean[m][j] - MeanGlobal[j]);
540  }
543  }
544  }
545 
546  #ifdef MULTITHREAD
547  #pragma omp for
548  #endif
549  for (int j = 0; j < nDraw; ++j)
550  {
553  }
554 
555  #ifdef MULTITHREAD
556  #pragma omp for
557  #endif
558  //Finally calculate our estimator
559  for (int j = 0; j < nDraw; ++j)
560  {
563 
564  //KS: For flat params values can be crazy so cap at 0
565  CapVariable(RHat[j], 0);
566  CapVariable(RHatFolded[j], 0);
// NOTE(review): with the signature CapVariable(double var, double cap) the
// value is passed by copy, so these calls cannot change RHat[j] or
// RHatFolded[j]. Capping in place needs a reference parameter.
567  }
568 
569  #ifdef MULTITHREAD
570  #pragma omp for
571  #endif
572  //KS: Additionally calculates effective step size which is an estimate of the sample size required to achieve the same level of precision if that sample was a simple random sample.
573  for (int j = 0; j < nDraw; ++j)
574  {
577 
578  //KS: For flat params values can be crazy so cap at 0
581  }
582  #ifdef MULTITHREAD
583  } //End parallel region
584  #endif
585 
586  clock.Stop();
587  MACH3LOG_INFO("Finished calculating RHat, it took {:.2f}s to finish", clock.RealTime());
588 }
#define MACH3LOG_INFO
Definition: MaCh3Logger.h:35
int Nchains
double * StandardDeviationGlobalFolded
void CapVariable(double var, double cap)
double * EffectiveSampleSizeFolded
double * BetweenChainVarianceFolded
double * MeanGlobal
double ** MeanFolded
double * BetweenChainVariance
double * MeanGlobalFolded
double ** StandardDeviation
double * StandardDeviationGlobal
double * MarginalPosteriorVariance
double * RHatFolded
double *** DrawsFolded
double ** StandardDeviationFolded
double *** Draws
double * RHat
double ** Mean
int nDraw
double * EffectiveSampleSize
double * MarginalPosteriorVarianceFolded
int Ntoys

◆ CapVariable()

void CapVariable ( double  var,
double  cap 
)

Definition at line 856 of file RHat_HighMem.cpp.

// *******************
/// @brief Intended to replace a non-finite value (NaN or +/-inf) with @p cap.
/// @param var Value to test.
/// @param cap Replacement used when @p var is not finite.
/// NOTE(review): @p var is passed by value, so the assignment below only
/// changes this function's local copy and the caller's variable is never
/// modified. Capping in place would need `double& var` here and in any
/// forward declaration of this function.
void CapVariable(double var, double cap) {
// *******************
  // std::isfinite is false for NaN as well as for +/-inf, so one test covers both
  const bool IsFinite = std::isfinite(var);
  if (!IsFinite) {
    var = cap;
  }
}

◆ DestroyArrays()

void DestroyArrays ( )

Definition at line 801 of file RHat_HighMem.cpp.

// DestroyArrays: releases the per-parameter result arrays, the per-chain
// Mean/StandardDeviation(Folded) rows and the [chain][toy] rows of
// Draws/DrawsFolded, then the outer pointer arrays and MedianArr.
// It relies on Nchains and Ntoys still matching the values used at allocation.
// The pointers are not reset to nullptr afterwards.
// NOTE(review): Mean[m], StandardDeviation[m] etc. are dereferenced
// unconditionally. This assumes InitialiseArrays ran; main() initialises
// these pointers to nullptr. TODO confirm call order.
// NOTE(review): this rendered listing skips source lines 813-815.
801  {
802 // *******************
803 
804  MACH3LOG_INFO("Killing all arrays");
805  delete[] MeanGlobal;
806  delete[] StandardDeviationGlobal;
807  delete[] BetweenChainVariance;
808  delete[] MarginalPosteriorVariance;
809  delete[] RHat;
810  delete[] EffectiveSampleSize;
811 
812  delete[] MeanGlobalFolded;
816  delete[] RHatFolded;
817  delete[] EffectiveSampleSizeFolded;
818 
819  for(int m = 0; m < Nchains; m++)
820  {
821  for(int i = 0; i < Ntoys; i++)
822  {
823  delete[] Draws[m][i];
824  delete[] DrawsFolded[m][i];
825  }
826  delete[] Draws[m];
827  delete[] Mean[m];
828  delete[] StandardDeviation[m];
829 
830  delete[] DrawsFolded[m];
831  delete[] MeanFolded[m];
832  delete[] StandardDeviationFolded[m];
833  }
834  delete[] Draws;
835  delete[] Mean;
836  delete[] StandardDeviation;
837 
838  delete[] DrawsFolded;
839  delete[] MedianArr;
840  delete[] MeanFolded;
841  delete[] StandardDeviationFolded;
842 }
double * MedianArr

◆ InitialiseArrays()

void InitialiseArrays ( )

Definition at line 372 of file RHat_HighMem.cpp.

// InitialiseArrays: allocates the result arrays used by CalcRhat:
//  - per-chain rows [Nchains][nDraw] for Mean/StandardDeviation and their
//    Folded counterparts;
//  - per-parameter arrays [nDraw] for the global quantities, R hat and the
//    effective sample size.
// Every `new double[n]()` value-initialises to 0.0, so the explicit zeroing
// loop below is redundant but harmless.
// NOTE(review): this rendered listing skips source lines 417, 422-424 and 426.
372  {
373 // *******************
374 
375  MACH3LOG_INFO("Initialising arrays");
376  Mean = new double*[Nchains]();
377  StandardDeviation = new double*[Nchains]();
378 
379  MeanGlobal = new double[nDraw]();
380  StandardDeviationGlobal = new double[nDraw]();
381  BetweenChainVariance = new double[nDraw]();
382 
383  MarginalPosteriorVariance = new double[nDraw]();
384  RHat = new double[nDraw]();
385  EffectiveSampleSize = new double[nDraw]();
386 
387  MeanFolded = new double*[Nchains]();
388  StandardDeviationFolded = new double*[Nchains]();
389 
390  MeanGlobalFolded = new double[nDraw]();
391  StandardDeviationGlobalFolded = new double[nDraw]();
392  BetweenChainVarianceFolded = new double[nDraw]();
393 
394  MarginalPosteriorVarianceFolded = new double[nDraw]();
395  RHatFolded = new double[nDraw]();
396  EffectiveSampleSizeFolded = new double[nDraw]();
397 
398  for (int m = 0; m < Nchains; ++m)
399  {
400  Mean[m] = new double[nDraw]();
401  StandardDeviation[m] = new double[nDraw]();
402 
403  MeanFolded[m] = new double[nDraw]();
404  StandardDeviationFolded[m] = new double[nDraw]();
405  for (int j = 0; j < nDraw; ++j)
406  {
407  Mean[m][j] = 0.;
408  StandardDeviation[m][j] = 0.;
409 
410  MeanFolded[m][j] = 0.;
411  StandardDeviationFolded[m][j] = 0.;
// Per-parameter (non per-chain) arrays only need zeroing once
412  if(m == 0)
413  {
414  MeanGlobal[j] = 0.;
415  StandardDeviationGlobal[j] = 0.;
416  BetweenChainVariance[j] = 0.;
418  RHat[j] = 0.;
419  EffectiveSampleSize[j] = 0.;
420 
421  MeanGlobalFolded[j] = 0.;
425  RHatFolded[j] = 0.;
427  }
428  }
429  }
430 }

◆ main()

int main ( int  argc,
char *  argv[] 
)

Definition at line 88 of file RHat_HighMem.cpp.

// main: entry point of the RHat tool.
// Usage: ./RHat Ntoys MCMCchain_1.root MCMCchain_2.root ...
//  - argv[1] is parsed with atoi into Ntoys. A non-numeric value yields 0
//    and is then rejected by the "Ntoys < 1" check.
//  - Every further argument is appended to MCMCFile and counts as one chain.
// Throws MaCh3Exception when fewer than two arguments follow the program
// name, or when Ntoys < 1. With a single chain it only warns.
// It then runs PrepareChains, RunDiagnostic, SaveResults and DestroyArrays.
// NOTE(review): this rendered listing skips source lines 91-92, 112, 115 and
// 153. InitialiseArrays() is not called in the visible lines. The symbol list
// rendered after this listing mentions SetMaCh3LoggerFormat, MaCh3Welcome and
// InitialiseArrays, so the missing lines presumably call them. TODO confirm.
88  {
89 // *******************
90 
93 
94  Draws = nullptr;
95  Mean = nullptr;
96  StandardDeviation = nullptr;
97 
98  MeanGlobal = nullptr;
99  StandardDeviationGlobal = nullptr;
100 
101  BetweenChainVariance = nullptr;
102  MarginalPosteriorVariance = nullptr;
103  RHat = nullptr;
104  EffectiveSampleSize = nullptr;
105 
106  DrawsFolded = nullptr;
107  MedianArr = nullptr;
108  MeanFolded = nullptr;
109  StandardDeviationFolded = nullptr;
110 
111  MeanGlobalFolded = nullptr;
113 
114  BetweenChainVarianceFolded = nullptr;
116  RHatFolded = nullptr;
117  EffectiveSampleSizeFolded = nullptr;
118 
119  Nchains = 0;
120 
121  if (argc == 1 || argc == 2)
122  {
123  MACH3LOG_ERROR("Wrong arguments");
124  MACH3LOG_ERROR("./RHat Ntoys MCMCchain_1.root MCMCchain_2.root MCMCchain_3.root ... [how many you like]");
125  throw MaCh3Exception(__FILE__ , __LINE__ );
126  }
127 
128  Ntoys = atoi(argv[1]);
129 
130  //KS Gelman suggests to diagnose on more than one chain
131  for (int i = 2; i < argc; i++)
132  {
133  MCMCFile.push_back(std::string(argv[i]));
134  MACH3LOG_INFO("Adding file: {}", MCMCFile.back());
135  Nchains++;
136  }
137 
138  if(Ntoys < 1)
139  {
140  MACH3LOG_ERROR("You specified {} specify larger greater than 0", Ntoys);
141  throw MaCh3Exception(__FILE__ , __LINE__ );
142  }
143 
144  if(Nchains == 1)
145  {
146  MACH3LOG_WARN("Gelman is going to be sad :(. He suggested you should use more than one chain (at least 4). Code works fine for one chain, however, estimator might be biased.");
147  MACH3LOG_WARN("Multiple chains are more likely to reveal multimodality and poor adaptation or mixing:");
148  }
149  MACH3LOG_INFO("Diagnosing {} chains, with {} toys", Nchains, Ntoys);
150 
151  PrepareChains();
152 
154 
155  //KS: Main function
156  RunDiagnostic();
157 
158  SaveResults();
159 
160  DestroyArrays();
161 
162  return 0;
163 }
#define MACH3LOG_ERROR
Definition: MaCh3Logger.h:37
void SetMaCh3LoggerFormat()
Set messaging format of the logger.
Definition: MaCh3Logger.h:61
#define MACH3LOG_WARN
Definition: MaCh3Logger.h:36
void SaveResults()
void InitialiseArrays()
void RunDiagnostic()
std::vector< std::string > MCMCFile
void DestroyArrays()
void PrepareChains()
MaCh3Exception: custom exception class used throughout MaCh3.
void MaCh3Welcome()
KS: Prints welcome message with MaCh3 logo.
Definition: Monitor.cpp:12

◆ PrepareChains()

void PrepareChains ( )

Definition at line 167 of file RHat_HighMem.cpp.

167  {
168 // *******************
169  auto rnd = std::make_unique<TRandom3>(0);
170 
171  MACH3LOG_INFO("Generating {}", Ntoys);
172 
173  TStopwatch clock;
174  clock.Start();
175 
176  std::vector<unsigned int> BurnIn(Nchains);
177  std::vector<unsigned int> nEntries(Nchains);
178  std::vector<int> nBranches(Nchains);
179  std::vector<unsigned int> step(Nchains);
180 
181  Draws = new double**[Nchains]();
182  DrawsFolded = new double**[Nchains]();
183 
184  // KS: This can reduce time necessary for caching even by half
185  #ifdef MULTITHREAD
186  //ROOT::EnableImplicitMT();
187  #endif
188 
189  // Open the Chain
190  //It is tempting to multithread here but unfortunately, ROOT files are not thread safe :(
191  for (int m = 0; m < Nchains; m++)
192  {
193  TChain* Chain = new TChain("posteriors");
194  Chain->Add(MCMCFile[m].c_str());
195  MACH3LOG_INFO("On file: {}", MCMCFile[m].c_str());
196  nEntries[m] = static_cast<unsigned int>(Chain->GetEntries());
197 
198  // Set the step cut to be 20%
199  BurnIn[m] = nEntries[m]/5;
200 
201  // Get the list of branches
202  TObjArray* brlis = Chain->GetListOfBranches();
203 
204  // Get the number of branches
205  nBranches[m] = brlis->GetEntries();
206 
207  if(m == 0) BranchNames.reserve(nBranches[m]);
208 
209  // Set all the branches to off
210  Chain->SetBranchStatus("*", false);
211 
212  // Loop over the number of branches
213  // Find the name and how many of each systematic we have
214  for (int i = 0; i < nBranches[m]; i++)
215  {
216  // Get the TBranch and its name
217  TBranch* br = static_cast<TBranch *>(brlis->At(i));
218  if(!br){
219  MACH3LOG_ERROR("Invalid branch at position {}", i);
220  throw MaCh3Exception(__FILE__,__LINE__);
221  }
222  TString bname = br->GetName();
223 
224  // Read in the step
225  if (bname == "step") {
226  Chain->SetBranchStatus(bname, true);
227  Chain->SetBranchAddress(bname, &step[m]);
228  }
229  //Count all branches
230  else if (bname.BeginsWith("PCA_") || bname.BeginsWith("accProb") || bname.BeginsWith("stepTime") )
231  {
232  continue;
233  }
234  else
235  {
236  //KS: Save branch name only for one chain, we assume all chains have the same branches, otherwise this doesn't make sense either way
237  if(m == 0)
238  {
239  BranchNames.push_back(bname);
240  //KS: We calculate R Hat also for LogL, just in case, however we plot them separately
241  if(bname.BeginsWith("LogL"))
242  {
243  ValidPar.push_back(false);
244  }
245  else
246  {
247  ValidPar.push_back(true);
248  }
249  }
250  Chain->SetBranchStatus(bname, true);
251  MACH3LOG_DEBUG("{}", bname);
252  }
253  }
254 
255  if(m == 0) nDraw = int(BranchNames.size());
256 
257  //TN: Qualitatively faster sanity check, with the very same outcome (all chains have the same #branches)
258  if(m > 0)
259  {
260  if(nBranches[m] != nBranches[0])
261  {
262  MACH3LOG_ERROR("Ups, something went wrong, chain {} called {} has {} branches, while 0 called {} has {} branches", m, MCMCFile[m], nBranches[m], MCMCFile[0], nBranches[0]);
263  MACH3LOG_ERROR("All chains should have the same number of branches");
264  throw MaCh3Exception(__FILE__ , __LINE__ );
265  }
266  }
267 
268  //TN: move the Draws here, so we need to iterate over every chain only once
269  Draws[m] = new double*[Ntoys]();
270  DrawsFolded[m] = new double*[Ntoys]();
271  for(int i = 0; i < Ntoys; i++)
272  {
273  Draws[m][i] = new double[nDraw]();
274  DrawsFolded[m][i] = new double[nDraw]();
275  for(int j = 0; j < nDraw; j++)
276  {
277  Draws[m][i][j] = 0.;
278  DrawsFolded[m][i][j] = 0.;
279  }
280  }
281 
282  // MJR: array to hold branch values; SetBranchAddress in every step is very
283  // expensive, so doing it once only here saves time
284  double* branch_values = new double[nDraw]();
285  for (int j = 0; j < nDraw; ++j)
286  {
287  Chain->SetBranchAddress(BranchNames[j].Data(), &branch_values[j]);
288  }
289 
290  //TN: move looping over toys here, so we don't need to loop over chains more than once
291  if(BurnIn[m] >= nEntries[m])
292  {
293  MACH3LOG_ERROR("You are running on a chain shorter than BurnIn cut");
294  MACH3LOG_ERROR("Number of entries {} BurnIn cut {}", nEntries[m], BurnIn[m]);
295  MACH3LOG_ERROR("You will run into the infinite loop");
296  MACH3LOG_ERROR("You can make a new chain or modify BurnIn cut");
297  throw MaCh3Exception(__FILE__ , __LINE__ );
298  }
299 
300  for (int i = 0; i < Ntoys; i++)
301  {
302  // Get a random entry after burn in
303  int entry = int(nEntries[m]*rnd->Rndm());
304 
305  Chain->GetEntry(entry);
306 
307  // If we have combined chains by hadd need to check the step in the chain
308  // Note, entry is not necessarily the same as the step due to merged ROOT files, so can't choose an entry in the range BurnIn - nEntries :(
309  if (step[m] < BurnIn[m])
310  {
311  i--;
312  continue;
313  }
314 
315  // Output some info for the user
316  if (Ntoys > 10 && i % (Ntoys/10) == 0) {
317  MaCh3Utils::PrintProgressBar(i+m*Ntoys, static_cast<Long64_t>(Ntoys)*Nchains);
318  MACH3LOG_DEBUG("Getting random entry {}", entry);
319  }
320 
321  // Set the branch addresses for params
322  for (int j = 0; j < nDraw; ++j) {
323  Draws[m][i][j] = branch_values[j];
324  }
325  }//end loop over toys
326 
327  //TN: There, we now don't need to keep the chain in memory anymore
328  delete Chain;
329  delete[] branch_values;
330  }
331 
332  //KS: Now prepare folded draws, quoting Gelman
333  //"We propose to report the maximum of rank normalized split-Rb and rank normalized folded-split-Rb for each parameter"
334  MedianArr = new double[nDraw]();
335  #ifdef MULTITHREAD
336  #pragma omp parallel for
337  #endif
338  for(int j = 0; j < nDraw; j++)
339  {
340  MedianArr[j] = 0.;
341  std::vector<double> TempDraws(static_cast<size_t>(Ntoys) * Nchains);
342  for(int m = 0; m < Nchains; m++)
343  {
344  for(int i = 0; i < Ntoys; i++)
345  {
346  const int im = i+m;
347  TempDraws[im] = Draws[m][i][j];
348  }
349  }
350  MedianArr[j] = CalcMedian(TempDraws.data(), Ntoys*Nchains);
351  }
352 
353  #ifdef MULTITHREAD
354  #pragma omp parallel for collapse(3)
355  #endif
356  for(int m = 0; m < Nchains; m++)
357  {
358  for(int i = 0; i < Ntoys; i++)
359  {
360  for(int j = 0; j < nDraw; j++)
361  {
362  DrawsFolded[m][i][j] = std::fabs(Draws[m][i][j] - MedianArr[j]);
363  }
364  }
365  }
366  clock.Stop();
367  MACH3LOG_INFO("Finished calculating Toys, it took {:.2f}s to finish", clock.RealTime());
368 }
#define MACH3LOG_DEBUG
Definition: MaCh3Logger.h:34
std::vector< bool > ValidPar
std::vector< TString > BranchNames
double CalcMedian(double arr[], int size)
void PrintProgressBar(const Long64_t Done, const Long64_t All)
KS: Simply print progress bar.
Definition: Monitor.cpp:228

◆ RunDiagnostic()

void RunDiagnostic ( )

Definition at line 433 of file RHat_HighMem.cpp.

433  {
434 // *******************
435  CalcRhat();
436  //In case in future we expand this
437 }
void CalcRhat()

◆ SaveResults()

void SaveResults ( )

Definition at line 592 of file RHat_HighMem.cpp.

592  {
593 // *******************
594  #pragma GCC diagnostic ignored "-Wfloat-conversion"
595 
596  std::string NameTemp = "";
597  //KS: If we run over many many chains there is danger that name will be so absurdly long we run over system limit and job will be killed :(
598  if(Nchains < 5)
599  {
600  for (int i = 0; i < Nchains; i++)
601  {
602  std::string temp = MCMCFile[i];
603 
604  while (temp.find(".root") != std::string::npos) {
605  temp = temp.substr(0, temp.find(".root"));
606  }
607  // Strip directory path
608  const auto slash = temp.find_last_of("/\\");
609  if (slash != std::string::npos) {
610  temp = temp.substr(slash + 1);
611  }
612  NameTemp = NameTemp + temp + "_";
613  }
614  }
615  else {
616  NameTemp = std::to_string(Nchains) + "Chains" + "_";
617  }
618  NameTemp += "diag.root";
619 
620  TFile *DiagFile = M3::Open(NameTemp, "recreate", __FILE__, __LINE__);
621  DiagFile->cd();
622 
623  TH1D *StandardDeviationGlobalPlot = new TH1D("StandardDeviationGlobalPlot", "StandardDeviationGlobalPlot", nDraw, 0, nDraw);
624  TH1D *BetweenChainVariancePlot = new TH1D("BetweenChainVariancePlot", "BetweenChainVariancePlot", nDraw, 0, nDraw);
625  TH1D *MarginalPosteriorVariancePlot = new TH1D("MarginalPosteriorVariancePlot", "MarginalPosteriorVariancePlot", nDraw, 0, nDraw);
626  TH1D *RhatPlot = new TH1D("RhatPlot", "RhatPlot", 200, 0, 2);
627  TH1D *EffectiveSampleSizePlot = new TH1D("EffectiveSampleSizePlot", "EffectiveSampleSizePlot", 400, 0, 10000);
628 
629  TH1D *StandardDeviationGlobalFoldedPlot = new TH1D("StandardDeviationGlobalFoldedPlot", "StandardDeviationGlobalFoldedPlot", nDraw, 0, nDraw);
630  TH1D *BetweenChainVarianceFoldedPlot = new TH1D("BetweenChainVarianceFoldedPlot", "BetweenChainVarianceFoldedPlot", nDraw, 0, nDraw);
631  TH1D *MarginalPosteriorVarianceFoldedPlot = new TH1D("MarginalPosteriorVarianceFoldedPlot", "MarginalPosteriorVarianceFoldedPlot", nDraw, 0, nDraw);
632  TH1D *RhatFoldedPlot = new TH1D("RhatFoldedPlot", "RhatFoldedPlot", 200, 0, 2);
633  TH1D *EffectiveSampleSizeFoldedPlot = new TH1D("EffectiveSampleSizeFoldedPlot", "EffectiveSampleSizeFoldedPlot", 400, 0, 10000);
634 
635  TH1D *RhatLogPlot = new TH1D("RhatLogPlot", "RhatLogPlot", 200, 0, 2);
636  TH1D *RhatFoldedLogPlot = new TH1D("RhatFoldedLogPlot", "RhatFoldedLogPlot", 200, 0, 2);
637 
638  int Criterium = 0;
639  int CiteriumFolded = 0;
640  for(int j = 0; j < nDraw; j++)
641  {
642  //KS: Fill only valid parameters
643  if(ValidPar[j])
644  {
645  StandardDeviationGlobalPlot->Fill(j,StandardDeviationGlobal[j]);
646  BetweenChainVariancePlot->Fill(j,BetweenChainVariance[j]);
647  MarginalPosteriorVariancePlot->Fill(j,MarginalPosteriorVariance[j]);
648  RhatPlot->Fill(RHat[j]);
649  EffectiveSampleSizePlot->Fill(EffectiveSampleSize[j]);
650  if(RHat[j] > 1.1) Criterium++;
651 
652 
653  StandardDeviationGlobalFoldedPlot->Fill(j,StandardDeviationGlobalFolded[j]);
654  BetweenChainVarianceFoldedPlot->Fill(j,BetweenChainVarianceFolded[j]);
655  MarginalPosteriorVarianceFoldedPlot->Fill(j,MarginalPosteriorVarianceFolded[j]);
656  RhatFoldedPlot->Fill(RHatFolded[j]);
657  EffectiveSampleSizeFoldedPlot->Fill(EffectiveSampleSizeFolded[j]);
658  if(RHatFolded[j] > 1.1) CiteriumFolded++;
659  }
660  else
661  {
662  RhatLogPlot->Fill(RHat[j]);
663  RhatFoldedLogPlot->Fill(RHatFolded[j]);
664  }
665  }
666  //KS: We set criterium of 1.1 based on Gelman et al. (2003) Bayesian Data Analysis
667  MACH3LOG_WARN("Number of parameters which has R hat greater than 1.1 is {}({:.2f}%) while for R hat folded {}({:.2f}%)", Criterium, 100*double(Criterium)/double(nDraw), CiteriumFolded, 100*double(CiteriumFolded)/double(nDraw));
668  for(int j = 0; j < nDraw; j++)
669  {
670  if( (RHat[j] > 1.1 || RHatFolded[j] > 1.1) && ValidPar[j])
671  {
672  MACH3LOG_CRITICAL("Parameter {} has R hat higher than 1.1", BranchNames[j]);
673  }
674  }
675  StandardDeviationGlobalPlot->Write();
676  BetweenChainVariancePlot->Write();
677  MarginalPosteriorVariancePlot->Write();
678  RhatPlot->Write();
679  EffectiveSampleSizePlot->Write();
680 
681  StandardDeviationGlobalFoldedPlot->Write();
682  BetweenChainVarianceFoldedPlot->Write();
683  MarginalPosteriorVarianceFoldedPlot->Write();
684  RhatFoldedPlot->Write();
685  EffectiveSampleSizeFoldedPlot->Write();
686 
687  RhatLogPlot->Write();
688  RhatFoldedLogPlot->Write();
689 
690  //KS: Now we make fancy canvases, consider some function to have less copy pasting
691  auto TempCanvas = std::make_unique<TCanvas>("Canvas", "Canvas", 1024, 1024);
692  gStyle->SetOptStat(0);
693  TempCanvas->SetGridx();
694  TempCanvas->SetGridy();
695 
696  // Random line to write useful information to TLegend
697  auto TempLine = std::make_unique<TLine>(0, 0, 0, 0);
698  TempLine->SetLineColor(kBlack);
699 
700  RhatPlot->GetXaxis()->SetTitle("R hat");
701  RhatPlot->SetLineColor(kRed);
702  RhatPlot->SetFillColor(kRed);
703  RhatFoldedPlot->SetLineColor(kBlue);
704  RhatFoldedPlot->SetFillColor(kBlue);
705 
706  TLegend *Legend = new TLegend(0.55, 0.6, 0.9, 0.9);
707  Legend->SetTextSize(0.04);
708  Legend->SetFillColor(0);
709  Legend->SetFillStyle(0);
710  Legend->SetLineWidth(0);
711  Legend->SetLineColor(0);
712 
713  Legend->AddEntry(TempLine.get(), Form("Number of throws=%.0i, Number of chains=%.1i", Ntoys, Nchains), "");
714  Legend->AddEntry(RhatPlot, "Rhat Gelman 2013", "l");
715  Legend->AddEntry(RhatFoldedPlot, "Rhat-Folded Gelman 2021", "l");
716 
717  RhatPlot->Draw();
718  RhatFoldedPlot->Draw("same");
719  Legend->Draw("same");
720  TempCanvas->Write("Rhat");
721  delete Legend;
722  Legend = nullptr;
723 
724  //Now R hat for log L
725  RhatLogPlot->GetXaxis()->SetTitle("R hat for LogL");
726  RhatLogPlot->SetLineColor(kRed);
727  RhatLogPlot->SetFillColor(kRed);
728  RhatFoldedLogPlot->SetLineColor(kBlue);
729  RhatFoldedLogPlot->SetFillColor(kBlue);
730 
731  Legend = new TLegend(0.55, 0.6, 0.9, 0.9);
732  Legend->SetTextSize(0.04);
733  Legend->SetFillColor(0);
734  Legend->SetFillStyle(0);
735  Legend->SetLineWidth(0);
736  Legend->SetLineColor(0);
737 
738  Legend->AddEntry(TempLine.get(), Form("Number of throws=%.0i, Number of chains=%.1i", Ntoys, Nchains), "");
739  Legend->AddEntry(RhatLogPlot, "Rhat Gelman 2013", "l");
740  Legend->AddEntry(RhatFoldedLogPlot, "Rhat-Folded Gelman 2021", "l");
741 
742  RhatLogPlot->Draw();
743  RhatFoldedLogPlot->Draw("same");
744  Legend->Draw("same");
745  TempCanvas->Write("RhatLog");
746  delete Legend;
747  Legend = nullptr;
748 
749  //Now canvas for effective sample size
750  EffectiveSampleSizePlot->GetXaxis()->SetTitle("S_{eff, BDA2}");
751  EffectiveSampleSizePlot->SetLineColor(kRed);
752  EffectiveSampleSizeFoldedPlot->SetLineColor(kBlue);
753 
754  Legend = new TLegend(0.45, 0.6, 0.9, 0.9);
755  Legend->SetTextSize(0.03);
756  Legend->SetFillColor(0);
757  Legend->SetFillStyle(0);
758  Legend->SetLineWidth(0);
759  Legend->SetLineColor(0);
760 
761  const double Mean1 = EffectiveSampleSizePlot->GetMean();
762  const double RMS1 = EffectiveSampleSizePlot->GetRMS();
763  const double Mean2 = EffectiveSampleSizeFoldedPlot->GetMean();
764  const double RMS2 = EffectiveSampleSizeFoldedPlot->GetRMS();
765 
766  Legend->AddEntry(TempLine.get(), Form("Number of throws=%.0i, Number of chains=%.1i", Ntoys, Nchains), "");
767  Legend->AddEntry(EffectiveSampleSizePlot, Form("S_{eff, BDA2} #mu = %.2f, #sigma = %.2f",Mean1 ,RMS1), "l");
768  Legend->AddEntry(EffectiveSampleSizeFoldedPlot, Form("S_{eff, BDA2} Folded, #mu = %.2f, #sigma = %.2f",Mean2 ,RMS2), "l");
769 
770  EffectiveSampleSizePlot->Draw();
771  EffectiveSampleSizeFoldedPlot->Draw("same");
772  Legend->Draw("same");
773  TempCanvas->Write("EffectiveSampleSize");
774 
775  //Fancy memory cleaning
776  delete StandardDeviationGlobalPlot;
777  delete BetweenChainVariancePlot;
778  delete MarginalPosteriorVariancePlot;
779  delete RhatPlot;
780  delete EffectiveSampleSizePlot;
781 
782  delete StandardDeviationGlobalFoldedPlot;
783  delete BetweenChainVarianceFoldedPlot;
784  delete MarginalPosteriorVarianceFoldedPlot;
785  delete RhatFoldedPlot;
786  delete EffectiveSampleSizeFoldedPlot;
787 
788  delete Legend;
789 
790  delete RhatLogPlot;
791  delete RhatFoldedLogPlot;
792 
793  DiagFile->Close();
794  delete DiagFile;
795 
796  MACH3LOG_INFO("Finished and wrote results to {}", NameTemp);
797 }
#define MACH3LOG_CRITICAL
Definition: MaCh3Logger.h:38
TFile * Open(const std::string &Name, const std::string &Type, const std::string &File, const int Line)
Opens a ROOT file with the given name and mode.

Variable Documentation

◆ BetweenChainVariance

double* BetweenChainVariance

Definition at line 56 of file RHat_HighMem.cpp.

◆ BetweenChainVarianceFolded

double* BetweenChainVarianceFolded

Definition at line 70 of file RHat_HighMem.cpp.

◆ BranchNames

std::vector<TString> BranchNames

Definition at line 44 of file RHat_HighMem.cpp.

◆ Draws

double*** Draws

Definition at line 48 of file RHat_HighMem.cpp.

◆ DrawsFolded

double*** DrawsFolded

Definition at line 61 of file RHat_HighMem.cpp.

◆ EffectiveSampleSize

double* EffectiveSampleSize

Definition at line 59 of file RHat_HighMem.cpp.

◆ EffectiveSampleSizeFolded

double* EffectiveSampleSizeFolded

Definition at line 73 of file RHat_HighMem.cpp.

◆ MarginalPosteriorVariance

double* MarginalPosteriorVariance

Definition at line 57 of file RHat_HighMem.cpp.

◆ MarginalPosteriorVarianceFolded

double* MarginalPosteriorVarianceFolded

Definition at line 71 of file RHat_HighMem.cpp.

◆ MCMCFile

std::vector<std::string> MCMCFile

Definition at line 45 of file RHat_HighMem.cpp.

◆ Mean

double** Mean

Definition at line 50 of file RHat_HighMem.cpp.

◆ MeanFolded

double** MeanFolded

Definition at line 64 of file RHat_HighMem.cpp.

◆ MeanGlobal

double* MeanGlobal

Definition at line 53 of file RHat_HighMem.cpp.

◆ MeanGlobalFolded

double* MeanGlobalFolded

Definition at line 67 of file RHat_HighMem.cpp.

◆ MedianArr

double* MedianArr

Definition at line 62 of file RHat_HighMem.cpp.

◆ Nchains

int Nchains

Definition at line 40 of file RHat_HighMem.cpp.

◆ nDraw

int nDraw

Definition at line 42 of file RHat_HighMem.cpp.

◆ Ntoys

int Ntoys

Definition at line 39 of file RHat_HighMem.cpp.

◆ RHat

double* RHat

Definition at line 58 of file RHat_HighMem.cpp.

◆ RHatFolded

double* RHatFolded

Definition at line 72 of file RHat_HighMem.cpp.

◆ StandardDeviation

double** StandardDeviation

Definition at line 51 of file RHat_HighMem.cpp.

◆ StandardDeviationFolded

double** StandardDeviationFolded

Definition at line 65 of file RHat_HighMem.cpp.

◆ StandardDeviationGlobal

double* StandardDeviationGlobal

Definition at line 54 of file RHat_HighMem.cpp.

◆ StandardDeviationGlobalFolded

double* StandardDeviationGlobalFolded

Definition at line 68 of file RHat_HighMem.cpp.

◆ ValidPar

std::vector<bool> ValidPar

Definition at line 46 of file RHat_HighMem.cpp.