MaCh3 2.2.1
Reference Guide
Loading...
Searching...
No Matches
RHat_HighMem.cpp
Go to the documentation of this file.
1// MaCh3 includes
2#include "Manager/Manager.h"
4
6// ROOT includes
7#include "TObjArray.h"
8#include "TChain.h"
9#include "TFile.h"
10#include "TBranch.h"
11#include "TCanvas.h"
12#include "TLine.h"
13#include "TLegend.h"
14#include "TString.h"
15#include "TH1.h"
16#include "TRandom3.h"
17#include "TStopwatch.h"
18#include "TColor.h"
19#include "TStyle.h"
20#include "TROOT.h"
22
31
32// *******************
35
37
38std::vector<TString> BranchNames;
39std::vector<std::string> MCMCFile;
40std::vector<bool> ValidPar;
41
42double ***Draws;
43
44double** Mean;
46
47double* MeanGlobal;
49
52double* RHat;
54
55double ***DrawsFolded;
56double* MedianArr;
57
58double** MeanFolded;
60
63
66double* RHatFolded;
68// *******************
69void PrepareChains();
70void InitialiseArrays();
71
72void RunDiagnostic();
73void CalcRhat();
74
75void SaveResults();
76void DestroyArrays();
77double CalcMedian(double arr[], int size);
78
79void CapVariable(double var, double cap);
80
81// *******************
82int main(int argc, char *argv[]) {
83// *******************
84
87
88 Draws = nullptr;
89 Mean = nullptr;
90 StandardDeviation = nullptr;
91
92 MeanGlobal = nullptr;
94
95 BetweenChainVariance = nullptr;
97 RHat = nullptr;
98 EffectiveSampleSize = nullptr;
99
100 DrawsFolded = nullptr;
101 MedianArr = nullptr;
102 MeanFolded = nullptr;
103 StandardDeviationFolded = nullptr;
104
105 MeanGlobalFolded = nullptr;
107
110 RHatFolded = nullptr;
112
113 Nchains = 0;
114
115 if (argc == 1 || argc == 2)
116 {
117 MACH3LOG_ERROR("Wrong arguments");
118 MACH3LOG_ERROR("./RHat Ntoys MCMCchain_1.root MCMCchain_2.root MCMCchain_3.root ... [how many you like]");
119 throw MaCh3Exception(__FILE__ , __LINE__ );
120 }
121
122 Ntoys = atoi(argv[1]);
123
124 //KS Gelman suggests to diagnose on more than one chain
125 for (int i = 2; i < argc; i++)
126 {
127 MCMCFile.push_back(std::string(argv[i]));
128 MACH3LOG_INFO("Adding file: {}", MCMCFile.back());
129 Nchains++;
130 }
131
132 if(Ntoys < 1)
133 {
134 MACH3LOG_ERROR("You specified {} specify larger greater than 0", Ntoys);
135 throw MaCh3Exception(__FILE__ , __LINE__ );
136 }
137
138 if(Nchains == 1)
139 {
140 MACH3LOG_WARN("Gelman is going to be sad :(. He suggested you should use more than one chain (at least 4). Code works fine for one chain, however, estimator might be biased.");
141 MACH3LOG_WARN("Multiple chains are more likely to reveal multimodality and poor adaptation or mixing:");
142 }
143 MACH3LOG_INFO("Diagnosing {} chains, with {} toys", Nchains, Ntoys);
144
146
148
149 //KS: Main function
151
152 SaveResults();
153
155
156 return 0;
157}
158
159// *******************
160// Load chain and prepare toys
162// *******************
163 auto rnd = std::make_unique<TRandom3>(0);
164
165 MACH3LOG_INFO("Generating {}", Ntoys);
166
167 TStopwatch clock;
168 clock.Start();
169
170 std::vector<int> BurnIn(Nchains);
171 std::vector<int> nEntries(Nchains);
172 std::vector<int> nBranches(Nchains);
173 std::vector<int> step(Nchains);
174
175 Draws = new double**[Nchains]();
176 DrawsFolded = new double**[Nchains]();
177
178 // KS: This can reduce time necessary for caching even by half
179 #ifdef MULTITHREAD
180 //ROOT::EnableImplicitMT();
181 #endif
182
183 // Open the Chain
184 //It is tempting to multithread here but unfortunately, ROOT files are not thread safe :(
185 for (int m = 0; m < Nchains; m++)
186 {
187 TChain* Chain = new TChain("posteriors");
188 Chain->Add(MCMCFile[m].c_str());
189 MACH3LOG_INFO("On file: {}", MCMCFile[m].c_str());
190 nEntries[m] = int(Chain->GetEntries());
191
192 // Set the step cut to be 20%
193 BurnIn[m] = nEntries[m]/5;
194
195 // Get the list of branches
196 TObjArray* brlis = Chain->GetListOfBranches();
197
198 // Get the number of branches
199 nBranches[m] = brlis->GetEntries();
200
201 if(m == 0) BranchNames.reserve(nBranches[m]);
202
203 // Set all the branches to off
204 Chain->SetBranchStatus("*", false);
205
206 // Loop over the number of branches
207 // Find the name and how many of each systematic we have
208 for (int i = 0; i < nBranches[m]; i++)
209 {
210 // Get the TBranch and its name
211 TBranch* br = static_cast<TBranch *>(brlis->At(i));
212 if(!br){
213 MACH3LOG_ERROR("Invalid branch at position {}", i);
214 throw MaCh3Exception(__FILE__,__LINE__);
215 }
216 TString bname = br->GetName();
217
218 // Read in the step
219 if (bname == "step") {
220 Chain->SetBranchStatus(bname, true);
221 Chain->SetBranchAddress(bname, &step[m]);
222 }
223 //Count all branches
224 else if (bname.BeginsWith("PCA_") || bname.BeginsWith("accProb") || bname.BeginsWith("stepTime") )
225 {
226 continue;
227 }
228 else
229 {
230 //KS: Save branch name only for one chain, we assume all chains have the same branches, otherwise this doesn't make sense either way
231 if(m == 0)
232 {
233 BranchNames.push_back(bname);
234 //KS: We calculate R Hat also for LogL, just in case, however we plot them separately
235 if(bname.BeginsWith("LogL"))
236 {
237 ValidPar.push_back(false);
238 }
239 else
240 {
241 ValidPar.push_back(true);
242 }
243 }
244 Chain->SetBranchStatus(bname, true);
245 MACH3LOG_DEBUG("{}", bname);
246 }
247 }
248
249 if(m == 0) nDraw = int(BranchNames.size());
250
251 //TN: Qualitatively faster sanity check, with the very same outcome (all chains have the same #branches)
252 if(m > 0)
253 {
254 if(nBranches[m] != nBranches[0])
255 {
256 MACH3LOG_ERROR("Ups, something went wrong, chain {} called {} has {} branches, while 0 called {} has {} branches", m, MCMCFile[m], nBranches[m], MCMCFile[0], nBranches[0]);
257 MACH3LOG_ERROR("All chains should have the same number of branches");
258 throw MaCh3Exception(__FILE__ , __LINE__ );
259 }
260 }
261
262 //TN: move the Draws here, so we need to iterate over every chain only once
263 Draws[m] = new double*[Ntoys]();
264 DrawsFolded[m] = new double*[Ntoys]();
265 for(int i = 0; i < Ntoys; i++)
266 {
267 Draws[m][i] = new double[nDraw]();
268 DrawsFolded[m][i] = new double[nDraw]();
269 for(int j = 0; j < nDraw; j++)
270 {
271 Draws[m][i][j] = 0.;
272 DrawsFolded[m][i][j] = 0.;
273 }
274 }
275
276 // MJR: array to hold branch values; SetBranchAddress in every step is very
277 // expensive, so doing it once only here saves time
278 double* branch_values = new double[nDraw]();
279 for (int j = 0; j < nDraw; ++j)
280 {
281 Chain->SetBranchAddress(BranchNames[j].Data(), &branch_values[j]);
282 }
283
284 //TN: move looping over toys here, so we don't need to loop over chains more than once
285 if(BurnIn[m] >= nEntries[m])
286 {
287 MACH3LOG_ERROR("You are running on a chain shorter than BurnIn cut");
288 MACH3LOG_ERROR("Number of entries {} BurnIn cut {}", nEntries[m], BurnIn[m]);
289 MACH3LOG_ERROR("You will run into the infinite loop");
290 MACH3LOG_ERROR("You can make a new chain or modify BurnIn cut");
291 throw MaCh3Exception(__FILE__ , __LINE__ );
292 }
293
294 for (int i = 0; i < Ntoys; i++)
295 {
296 // Get a random entry after burn in
297 int entry = int(nEntries[m]*rnd->Rndm());
298
299 Chain->GetEntry(entry);
300
301 // If we have combined chains by hadd need to check the step in the chain
302 // Note, entry is not necessarily the same as the step due to merged ROOT files, so can't choose an entry in the range BurnIn - nEntries :(
303 if (step[m] < BurnIn[m])
304 {
305 i--;
306 continue;
307 }
308
309 // Output some info for the user
310 if (Ntoys > 10 && i % (Ntoys/10) == 0) {
311 MaCh3Utils::PrintProgressBar(i+m*Ntoys, static_cast<Long64_t>(Ntoys)*Nchains);
312 MACH3LOG_DEBUG("Getting random entry {}", entry);
313 }
314
315 // Set the branch addresses for params
316 for (int j = 0; j < nDraw; ++j)
317 {
318 Draws[m][i][j] = branch_values[j];
319 }
320
321 }//end loop over toys
322
323 //TN: There, we now don't need to keep the chain in memory anymore
324 delete Chain;
325 delete[] branch_values;
326 }
327
328 //KS: Now prepare folded draws, quoting Gelman
329 //"We propose to report the maximum of rank normalized split-Rb and rank normalized folded-split-Rb for each parameter"
330 MedianArr = new double[nDraw]();
331 #ifdef MULTITHREAD
332 #pragma omp parallel for
333 #endif
334 for(int j = 0; j < nDraw; j++)
335 {
336 MedianArr[j] = 0.;
337 std::vector<double> TempDraws(static_cast<size_t>(Ntoys) * Nchains);
338 for(int m = 0; m < Nchains; m++)
339 {
340 for(int i = 0; i < Ntoys; i++)
341 {
342 const int im = i+m;
343 TempDraws[im] = Draws[m][i][j];
344 }
345 }
346 MedianArr[j] = CalcMedian(TempDraws.data(), Ntoys*Nchains);
347 }
348
349 #ifdef MULTITHREAD
350 #pragma omp parallel for collapse(3)
351 #endif
352 for(int m = 0; m < Nchains; m++)
353 {
354 for(int i = 0; i < Ntoys; i++)
355 {
356 for(int j = 0; j < nDraw; j++)
357 {
358 DrawsFolded[m][i][j] = std::fabs(Draws[m][i][j] - MedianArr[j]);
359 }
360 }
361 }
362 clock.Stop();
363 MACH3LOG_INFO("Finished calculating Toys, it took {:.2f}s to finish", clock.RealTime());
364}
365
366// *******************
367// Create all arrays we are going to use later
369// *******************
370
371 MACH3LOG_INFO("Initialising arrays");
372 Mean = new double*[Nchains]();
373 StandardDeviation = new double*[Nchains]();
374
375 MeanGlobal = new double[nDraw]();
376 StandardDeviationGlobal = new double[nDraw]();
377 BetweenChainVariance = new double[nDraw]();
378
379 MarginalPosteriorVariance = new double[nDraw]();
380 RHat = new double[nDraw]();
381 EffectiveSampleSize = new double[nDraw]();
382
383 MeanFolded = new double*[Nchains]();
384 StandardDeviationFolded = new double*[Nchains]();
385
386 MeanGlobalFolded = new double[nDraw]();
387 StandardDeviationGlobalFolded = new double[nDraw]();
388 BetweenChainVarianceFolded = new double[nDraw]();
389
391 RHatFolded = new double[nDraw]();
392 EffectiveSampleSizeFolded = new double[nDraw]();
393
394 for (int m = 0; m < Nchains; ++m)
395 {
396 Mean[m] = new double[nDraw]();
397 StandardDeviation[m] = new double[nDraw]();
398
399 MeanFolded[m] = new double[nDraw]();
400 StandardDeviationFolded[m] = new double[nDraw]();
401 for (int j = 0; j < nDraw; ++j)
402 {
403 Mean[m][j] = 0.;
404 StandardDeviation[m][j] = 0.;
405
406 MeanFolded[m][j] = 0.;
407 StandardDeviationFolded[m][j] = 0.;
408 if(m == 0)
409 {
410 MeanGlobal[j] = 0.;
412 BetweenChainVariance[j] = 0.;
414 RHat[j] = 0.;
415 EffectiveSampleSize[j] = 0.;
416
417 MeanGlobalFolded[j] = 0.;
421 RHatFolded[j] = 0.;
423 }
424 }
425 }
426}
427
428// *******************
430// *******************
431 CalcRhat();
432 //In case in future we expand this
433}
434
435// *******************
436//KS: Based on Gelman et. al. arXiv:1903.08008v5
437// Probably most of it could be moved cleverly to MCMC Processor, keep it separate for now
438void CalcRhat() {
439// *******************
440
441 TStopwatch clock;
442 clock.Start();
443
444 //KS: Start parallel region
445 // If we would like to do this for thousands of chains we might consider using GPU for this
446 #ifdef MULTITHREAD
447 #pragma omp parallel
448 {
449 #endif
450
451 #ifdef MULTITHREAD
452 #pragma omp for collapse(2)
453 #endif
454 //KS: loop over chains and draws are independent so might as well collapse for sweet cache hits
455 //Calculate the mean for each parameter within each considered chain
456 for (int m = 0; m < Nchains; ++m)
457 {
458 for (int j = 0; j < nDraw; ++j)
459 {
460 for(int i = 0; i < Ntoys; i++)
461 {
462 Mean[m][j] += Draws[m][i][j];
463 MeanFolded[m][j] += DrawsFolded[m][i][j];
464 }
465 Mean[m][j] = Mean[m][j]/Ntoys;
466 MeanFolded[m][j] = MeanFolded[m][j]/Ntoys;
467 }
468 }
469
470 #ifdef MULTITHREAD
471 #pragma omp for
472 #endif
473 //Calculate the mean for each parameter global means we include information from several chains
474 for (int j = 0; j < nDraw; ++j)
475 {
476 for (int m = 0; m < Nchains; ++m)
477 {
478 MeanGlobal[j] += Mean[m][j];
479 MeanGlobalFolded[j] += MeanFolded[m][j];
480 }
483 }
484
485
486 #ifdef MULTITHREAD
487 #pragma omp for collapse(2)
488 #endif
489 //Calculate the standard deviation for each parameter within each considered chain
490 for (int m = 0; m < Nchains; ++m)
491 {
492 for (int j = 0; j < nDraw; ++j)
493 {
494 for(int i = 0; i < Ntoys; i++)
495 {
496 StandardDeviation[m][j] += (Draws[m][i][j] - Mean[m][j])*(Draws[m][i][j] - Mean[m][j]);
497 StandardDeviationFolded[m][j] += (DrawsFolded[m][i][j] - MeanFolded[m][j])*(DrawsFolded[m][i][j] - MeanFolded[m][j]);
498 }
499 StandardDeviation[m][j] = StandardDeviation[m][j]/(Ntoys-1);
501 }
502 }
503
504 #ifdef MULTITHREAD
505 #pragma omp for
506 #endif
507 //Calculate the standard deviation for each parameter combining information from all chains
508 for (int j = 0; j < nDraw; ++j)
509 {
510 for (int m = 0; m < Nchains; ++m)
511 {
514 }
517 }
518
519 #ifdef MULTITHREAD
520 #pragma omp for
521 #endif
522 for (int j = 0; j < nDraw; ++j)
523 {
524 //KS: This term only makes sense if we have at least 2 chains
525 if(Nchains == 1)
526 {
527 BetweenChainVariance[j] = 0.;
529 }
530 else
531 {
532 for (int m = 0; m < Nchains; ++m)
533 {
534 BetweenChainVariance[j] += ( Mean[m][j] - MeanGlobal[j])*( Mean[m][j] - MeanGlobal[j]);
536 }
539 }
540 }
541
542 #ifdef MULTITHREAD
543 #pragma omp for
544 #endif
545 for (int j = 0; j < nDraw; ++j)
546 {
549 }
550
551 #ifdef MULTITHREAD
552 #pragma omp for
553 #endif
554 //Finally calculate our estimator
555 for (int j = 0; j < nDraw; ++j)
556 {
559
560 //KS: For flat params values can be crazy so cap at 0
561 CapVariable(RHat[j], 0);
562 CapVariable(RHatFolded[j], 0);
563 }
564
565 #ifdef MULTITHREAD
566 #pragma omp for
567 #endif
568 //KS: Additionally calculates effective step size which is an estimate of the sample size required to achieve the same level of precision if that sample was a simple random sample.
569 for (int j = 0; j < nDraw; ++j)
570 {
573
574 //KS: For flat params values can be crazy so cap at 0
577 }
578 #ifdef MULTITHREAD
579 } //End parallel region
580 #endif
581
582 clock.Stop();
583 MACH3LOG_INFO("Finished calculating RHat, it took {:.2f}s to finish", clock.RealTime());
584}
585
586
587// *******************
589// *******************
590 #pragma GCC diagnostic ignored "-Wfloat-conversion"
591
592 std::string NameTemp = "";
593 //KS: If we run over many many chains there is danger that name will be so absurdly long we run over system limit and job will be killed :(
594 if(Nchains < 5)
595 {
596 for (int i = 0; i < Nchains; i++)
597 {
598 std::string temp = MCMCFile[i];
599
600 while (temp.find(".root") != std::string::npos) {
601 temp = temp.substr(0, temp.find(".root"));
602 }
603
604 NameTemp = NameTemp + temp + "_";
605 }
606 }
607 else {
608 NameTemp = std::to_string(Nchains) + "Chains" + "_";
609 }
610 NameTemp += "diag.root";
611
612 TFile* DiagFile = new TFile(NameTemp.c_str(), "recreate");
613
614 DiagFile->cd();
615
616 TH1D *StandardDeviationGlobalPlot = new TH1D("StandardDeviationGlobalPlot", "StandardDeviationGlobalPlot", nDraw, 0, nDraw);
617 TH1D *BetweenChainVariancePlot = new TH1D("BetweenChainVariancePlot", "BetweenChainVariancePlot", nDraw, 0, nDraw);
618 TH1D *MarginalPosteriorVariancePlot = new TH1D("MarginalPosteriorVariancePlot", "MarginalPosteriorVariancePlot", nDraw, 0, nDraw);
619 TH1D *RhatPlot = new TH1D("RhatPlot", "RhatPlot", 200, 0, 2);
620 TH1D *EffectiveSampleSizePlot = new TH1D("EffectiveSampleSizePlot", "EffectiveSampleSizePlot", 400, 0, 10000);
621
622 TH1D *StandardDeviationGlobalFoldedPlot = new TH1D("StandardDeviationGlobalFoldedPlot", "StandardDeviationGlobalFoldedPlot", nDraw, 0, nDraw);
623 TH1D *BetweenChainVarianceFoldedPlot = new TH1D("BetweenChainVarianceFoldedPlot", "BetweenChainVarianceFoldedPlot", nDraw, 0, nDraw);
624 TH1D *MarginalPosteriorVarianceFoldedPlot = new TH1D("MarginalPosteriorVarianceFoldedPlot", "MarginalPosteriorVarianceFoldedPlot", nDraw, 0, nDraw);
625 TH1D *RhatFoldedPlot = new TH1D("RhatFoldedPlot", "RhatFoldedPlot", 200, 0, 2);
626 TH1D *EffectiveSampleSizeFoldedPlot = new TH1D("EffectiveSampleSizeFoldedPlot", "EffectiveSampleSizeFoldedPlot", 400, 0, 10000);
627
628 TH1D *RhatLogPlot = new TH1D("RhatLogPlot", "RhatLogPlot", 200, 0, 2);
629 TH1D *RhatFoldedLogPlot = new TH1D("RhatFoldedLogPlot", "RhatFoldedLogPlot", 200, 0, 2);
630
631 int Criterium = 0;
632 int CiteriumFolded = 0;
633 for(int j = 0; j < nDraw; j++)
634 {
635 //KS: Fill only valid parameters
636 if(ValidPar[j])
637 {
638 StandardDeviationGlobalPlot->Fill(j,StandardDeviationGlobal[j]);
639 BetweenChainVariancePlot->Fill(j,BetweenChainVariance[j]);
640 MarginalPosteriorVariancePlot->Fill(j,MarginalPosteriorVariance[j]);
641 RhatPlot->Fill(RHat[j]);
642 EffectiveSampleSizePlot->Fill(EffectiveSampleSize[j]);
643 if(RHat[j] > 1.1) Criterium++;
644
645
646 StandardDeviationGlobalFoldedPlot->Fill(j,StandardDeviationGlobalFolded[j]);
647 BetweenChainVarianceFoldedPlot->Fill(j,BetweenChainVarianceFolded[j]);
648 MarginalPosteriorVarianceFoldedPlot->Fill(j,MarginalPosteriorVarianceFolded[j]);
649 RhatFoldedPlot->Fill(RHatFolded[j]);
650 EffectiveSampleSizeFoldedPlot->Fill(EffectiveSampleSizeFolded[j]);
651 if(RHatFolded[j] > 1.1) CiteriumFolded++;
652 }
653 else
654 {
655 RhatLogPlot->Fill(RHat[j]);
656 RhatFoldedLogPlot->Fill(RHatFolded[j]);
657 }
658 }
659 //KS: We set criterium of 1.1 based on Gelman et al. (2003) Bayesian Data Analysis
660 MACH3LOG_WARN("Number of parameters which has R hat greater than 1.1 is {}({:.2f}%) while for R hat folded {}({:.2f}%)", Criterium, 100*double(Criterium)/double(nDraw), CiteriumFolded, 100*double(CiteriumFolded)/double(nDraw));
661 for(int j = 0; j < nDraw; j++)
662 {
663 if( (RHat[j] > 1.1 || RHatFolded[j] > 1.1) && ValidPar[j])
664 {
665 MACH3LOG_CRITICAL("Parameter {} has R hat higher than 1.1", BranchNames[j]);
666 }
667 }
668 StandardDeviationGlobalPlot->Write();
669 BetweenChainVariancePlot->Write();
670 MarginalPosteriorVariancePlot->Write();
671 RhatPlot->Write();
672 EffectiveSampleSizePlot->Write();
673
674 StandardDeviationGlobalFoldedPlot->Write();
675 BetweenChainVarianceFoldedPlot->Write();
676 MarginalPosteriorVarianceFoldedPlot->Write();
677 RhatFoldedPlot->Write();
678 EffectiveSampleSizeFoldedPlot->Write();
679
680 RhatLogPlot->Write();
681 RhatFoldedLogPlot->Write();
682
683 //KS: Now we make fancy canvases, consider some function to have less copy pasting
684 auto TempCanvas = std::make_unique<TCanvas>("Canvas", "Canvas", 1024, 1024);
685 gStyle->SetOptStat(0);
686 TempCanvas->SetGridx();
687 TempCanvas->SetGridy();
688
689 // Random line to write useful information to TLegend
690 auto TempLine = std::make_unique<TLine>(0, 0, 0, 0);
691 TempLine->SetLineColor(kBlack);
692
693 RhatPlot->GetXaxis()->SetTitle("R hat");
694 RhatPlot->SetLineColor(kRed);
695 RhatPlot->SetFillColor(kRed);
696 RhatFoldedPlot->SetLineColor(kBlue);
697 RhatFoldedPlot->SetFillColor(kBlue);
698
699 TLegend *Legend = new TLegend(0.55, 0.6, 0.9, 0.9);
700 Legend->SetTextSize(0.04);
701 Legend->SetFillColor(0);
702 Legend->SetFillStyle(0);
703 Legend->SetLineWidth(0);
704 Legend->SetLineColor(0);
705
706 Legend->AddEntry(TempLine.get(), Form("Number of throws=%.0i, Number of chains=%.1i", Ntoys, Nchains), "");
707 Legend->AddEntry(RhatPlot, "Rhat Gelman 2013", "l");
708 Legend->AddEntry(RhatFoldedPlot, "Rhat-Folded Gelman 2021", "l");
709
710 RhatPlot->Draw();
711 RhatFoldedPlot->Draw("same");
712 Legend->Draw("same");
713 TempCanvas->Write("Rhat");
714 delete Legend;
715 Legend = nullptr;
716
717 //Now R hat for log L
718 RhatLogPlot->GetXaxis()->SetTitle("R hat for LogL");
719 RhatLogPlot->SetLineColor(kRed);
720 RhatLogPlot->SetFillColor(kRed);
721 RhatFoldedLogPlot->SetLineColor(kBlue);
722 RhatFoldedLogPlot->SetFillColor(kBlue);
723
724 Legend = new TLegend(0.55, 0.6, 0.9, 0.9);
725 Legend->SetTextSize(0.04);
726 Legend->SetFillColor(0);
727 Legend->SetFillStyle(0);
728 Legend->SetLineWidth(0);
729 Legend->SetLineColor(0);
730
731 Legend->AddEntry(TempLine.get(), Form("Number of throws=%.0i, Number of chains=%.1i", Ntoys, Nchains), "");
732 Legend->AddEntry(RhatLogPlot, "Rhat Gelman 2013", "l");
733 Legend->AddEntry(RhatFoldedLogPlot, "Rhat-Folded Gelman 2021", "l");
734
735 RhatLogPlot->Draw();
736 RhatFoldedLogPlot->Draw("same");
737 Legend->Draw("same");
738 TempCanvas->Write("RhatLog");
739 delete Legend;
740 Legend = nullptr;
741
742 //Now canvas for effective sample size
743 EffectiveSampleSizePlot->GetXaxis()->SetTitle("S_{eff, BDA2}");
744 EffectiveSampleSizePlot->SetLineColor(kRed);
745 EffectiveSampleSizeFoldedPlot->SetLineColor(kBlue);
746
747 Legend = new TLegend(0.45, 0.6, 0.9, 0.9);
748 Legend->SetTextSize(0.03);
749 Legend->SetFillColor(0);
750 Legend->SetFillStyle(0);
751 Legend->SetLineWidth(0);
752 Legend->SetLineColor(0);
753
754 const double Mean1 = EffectiveSampleSizePlot->GetMean();
755 const double RMS1 = EffectiveSampleSizePlot->GetRMS();
756 const double Mean2 = EffectiveSampleSizeFoldedPlot->GetMean();
757 const double RMS2 = EffectiveSampleSizeFoldedPlot->GetRMS();
758
759 Legend->AddEntry(TempLine.get(), Form("Number of throws=%.0i, Number of chains=%.1i", Ntoys, Nchains), "");
760 Legend->AddEntry(EffectiveSampleSizePlot, Form("S_{eff, BDA2} #mu = %.2f, #sigma = %.2f",Mean1 ,RMS1), "l");
761 Legend->AddEntry(EffectiveSampleSizeFoldedPlot, Form("S_{eff, BDA2} Folded, #mu = %.2f, #sigma = %.2f",Mean2 ,RMS2), "l");
762
763 EffectiveSampleSizePlot->Draw();
764 EffectiveSampleSizeFoldedPlot->Draw("same");
765 Legend->Draw("same");
766 TempCanvas->Write("EffectiveSampleSize");
767
768 //Fancy memory cleaning
769 delete StandardDeviationGlobalPlot;
770 delete BetweenChainVariancePlot;
771 delete MarginalPosteriorVariancePlot;
772 delete RhatPlot;
773 delete EffectiveSampleSizePlot;
774
775 delete StandardDeviationGlobalFoldedPlot;
776 delete BetweenChainVarianceFoldedPlot;
777 delete MarginalPosteriorVarianceFoldedPlot;
778 delete RhatFoldedPlot;
779 delete EffectiveSampleSizeFoldedPlot;
780
781 delete Legend;
782
783 delete RhatLogPlot;
784 delete RhatFoldedLogPlot;
785
786 DiagFile->Close();
787 delete DiagFile;
788
789 MACH3LOG_INFO("Finished and wrote results to {}", NameTemp);
790}
791
792// *******************
793//KS: Pseudo destructor
795// *******************
796
797 MACH3LOG_INFO("Killing all arrays");
798 delete[] MeanGlobal;
800 delete[] BetweenChainVariance;
802 delete[] RHat;
803 delete[] EffectiveSampleSize;
804
805 delete[] MeanGlobalFolded;
809 delete[] RHatFolded;
811
812 for(int m = 0; m < Nchains; m++)
813 {
814 for(int i = 0; i < Ntoys; i++)
815 {
816 delete[] Draws[m][i];
817 delete[] DrawsFolded[m][i];
818 }
819 delete[] Draws[m];
820 delete[] Mean[m];
821 delete[] StandardDeviation[m];
822
823 delete[] DrawsFolded[m];
824 delete[] MeanFolded[m];
825 delete[] StandardDeviationFolded[m];
826 }
827 delete[] Draws;
828 delete[] Mean;
829 delete[] StandardDeviation;
830
831 delete[] DrawsFolded;
832 delete[] MedianArr;
833 delete[] MeanFolded;
835}
836
// *******************
// Calculate the median of the first `size` elements of `arr`.
// The elements are sorted in place, so the caller's buffer is reordered.
// Returns the middle element for an odd `size`, the mean of the two middle
// elements for an even `size`, and NaN when there is nothing to take the
// median of (null pointer or size <= 0).
double CalcMedian(double arr[], const int size) {
// *******************
  // Guard: without it arr[0] would be read out of bounds for an empty input
  if (arr == nullptr || size <= 0) return std::nan("");

  std::sort(arr, arr + size);

  const int mid = size / 2;
  if (size % 2 != 0) return arr[mid];

  // Even number of elements: average the two central values
  return (arr[mid - 1] + arr[mid]) / 2.0;
}
846
// *******************
// Replace a NaN or infinite value with `cap`.
// NOTE(review): `var` is received by value, so the replacement only changes the
// local copy and the caller's variable is left untouched. Making the cap
// effective needs `double& var` in both this definition and its forward
// declaration, changed together.
void CapVariable(double var, const double cap) {
// *******************
  // std::isfinite() is false for NaN as well as for +/-infinity
  const bool NotFinite = !std::isfinite(var);
  if (NotFinite) {
    var = cap;
  }
}
#define _MaCh3_Safe_Include_Start_
KS: Avoiding warning checking for headers.
Definition: Core.h:106
#define _MaCh3_Safe_Include_End_
KS: Restore warning checking after including external headers.
Definition: Core.h:117
int size
#define MACH3LOG_CRITICAL
Definition: MaCh3Logger.h:26
#define MACH3LOG_DEBUG
Definition: MaCh3Logger.h:22
#define MACH3LOG_ERROR
Definition: MaCh3Logger.h:25
#define MACH3LOG_INFO
Definition: MaCh3Logger.h:23
void SetMaCh3LoggerFormat()
Set messaging format of the logger.
Definition: MaCh3Logger.h:30
#define MACH3LOG_WARN
Definition: MaCh3Logger.h:24
int main(int argc, char *argv[])
void SaveResults()
int Nchains
double * StandardDeviationGlobalFolded
void CapVariable(double var, double cap)
double * EffectiveSampleSizeFolded
double * BetweenChainVarianceFolded
void InitialiseArrays()
double * MeanGlobal
double ** MeanFolded
void RunDiagnostic()
std::vector< bool > ValidPar
double * BetweenChainVariance
double * MeanGlobalFolded
double ** StandardDeviation
double * StandardDeviationGlobal
double * MarginalPosteriorVariance
double * RHatFolded
double * MedianArr
double *** DrawsFolded
double ** StandardDeviationFolded
std::vector< TString > BranchNames
double *** Draws
std::vector< std::string > MCMCFile
double * RHat
double ** Mean
void DestroyArrays()
void CalcRhat()
double CalcMedian(double arr[], int size)
int nDraw
double * EffectiveSampleSize
void PrepareChains()
double * MarginalPosteriorVarianceFolded
int Ntoys
Custom exception class for MaCh3 errors.
void PrintProgressBar(const Long64_t Done, const Long64_t All)
KS: Simply print progress bar.
Definition: Monitor.cpp:212
void MaCh3Welcome()
KS: Prints welcome message with MaCh3 logo.
Definition: Monitor.cpp:11