MaCh3  2.4.2
Reference Guide
splines.cpp
Go to the documentation of this file.
1 // pybind includes
2 #include <pybind11/pybind11.h>
3 #include <pybind11/stl.h>
4 #include <pybind11/numpy.h>
5 // MaCh3 includes
6 #include "Splines/SplineBase.h"
9 #include "Samples/SampleStructs.h" // <- The spline stuff that's in here should really be moved to splineStructs.h but I ain't doing that right now
10 // ROOT includes
11 #include "TSpline.h"
12 
13 #pragma GCC diagnostic ignored "-Wuseless-cast"
14 #pragma GCC diagnostic ignored "-Wfloat-conversion"
15 
16 namespace py = pybind11;
17 
// PySplineBase: pybind11 "trampoline" for the abstract SplineBase, so Python subclasses can
// implement its virtual methods. Each override below forwards the call to the Python method
// named by the string literal (e.g. "evaluate"). The last macro argument is the C++ method
// name, which intentionally differs from the Python name despite the inline
// "(must match Python name)" comments.
// PYBIND11_OVERRIDE_PURE_NAME makes pybind11 raise if the Python subclass does not provide the
// method; PYBIND11_OVERRIDE_NAME (used for get_name) falls back to SplineBase::GetName instead.
// NOTE(review): this rendered listing drops original line 22 (presumably the constructor
// inheritance announced by the comment below) and line 43 (the FindSplineSegment signature,
// indexed as `void FindSplineSegment()`, apparently without `override`) -- confirm against the source file.
19 class PySplineBase : public SplineBase {
20 public:
21  /* Inherit the constructors */
23 
24  /* Trampoline (need one for each virtual function) */
// Python name: "evaluate" (pure: a Python subclass must implement it)
25  void Evaluate() override {
26  PYBIND11_OVERRIDE_PURE_NAME(
27  void, /* Return type */
28  SplineBase, /* Parent class */
29  "evaluate", /* Name in python*/
30  Evaluate /* Name of function in C++ (must match Python name) */
31  );
32  }
33 
// Python name: "get_name" (non-pure: falls back to SplineBase::GetName if not overridden)
34  std::string GetName() const override {
35  PYBIND11_OVERRIDE_NAME(
36  std::string, /* Return type */
37  SplineBase, /* Parent class */
38  "get_name", /* Name in python*/
39  GetName /* Name of function in C++ (must match Python name) */
40  );
41  }
42 
// Python name: "find_segment" (the signature line of this method is missing from this listing)
44  PYBIND11_OVERRIDE_PURE_NAME(
45  void, /* Return type */
46  SplineBase, /* Parent class */
47  "find_segment", /* Name in python*/
48  FindSplineSegment /* Name of function in C++ (must match Python name) */
49  );
50  }
51 
// Python name: "calculate_weights" (pure)
52  void CalcSplineWeights() override {
53  PYBIND11_OVERRIDE_PURE_NAME(
54  void, /* Return type */
55  SplineBase, /* Parent class */
56  "calculate_weights", /* Name in python*/
57  CalcSplineWeights /* Name of function in C++ (must match Python name) */
58  );
59  }
60 
// Python name: "modify_weights" (pure)
61  void ModifyWeights() override {
62  PYBIND11_OVERRIDE_PURE_NAME(
63  void, /* Return type */
64  SplineBase, /* Parent class */
65  "modify_weights", /* Name in python*/
66  ModifyWeights /* Name of function in C++ (must match Python name) */
67  );
68  }
69 };
70 
71 
72 void initSplines(py::module &m) {
73  auto m_splines = m.def_submodule("splines");
74  m_splines.doc() =
75  "This is a Python binding of MaCh3s C++ based spline library.";
76 
77  // Bind the interpolation type enum that lets us set different interpolation types for our splines
78  py::enum_<SplineInterpolation>(m_splines, "InterpolationType")
79  .value("Linear", SplineInterpolation::kLinear, "Linear interpolation between the knots")
80  .value("Linear_Func", SplineInterpolation::kLinearFunc, "Same as 'Linear'")
81  .value("Cubic_TSpline3", SplineInterpolation::kTSpline3, "Use same coefficients as `ROOT's TSpline3 <https://root.cern.ch/doc/master/classTSpline3.html>`_ implementation")
82  .value("Cubic_Monotonic", SplineInterpolation::kMonotonic, "Coefficients are calculated such that the segments between knots are forced to be monotonic. The implementation we use is based on `this method <https://www.jstor.org/stable/2156610>`_ by Fritsch and Carlson.")
83  .value("Cubic_Akima", SplineInterpolation::kAkima, "The second derivative is not required to be continuous at the knots. This means that these splines are useful if the second derivative is rapidly varying. The implementation we used is based on `this paper <http://www.leg.ufpr.br/lib/exe/fetch.php/wiki:internas:biblioteca:akima.pdf>`_ by Akima.")
84  .value("N_Interpolation_Types", SplineInterpolation::kSplineInterpolations, "This is only to be used when iterating and is not a valid interpolation type.");
85 
86 
87  py::class_<SplineBase, PySplineBase /* <--- trampoline*/>(m_splines, "SplineBase");
88 
89  py::class_<TResponseFunction_red>(m_splines, "_ResponseFunctionBase")
90  .doc() = "Base class of the response function, this binding only exists for consistency with the inheritance structure of the c++ code. Just pretend it doesn't exist and don't worry about it...";
91 
92  // Bind the TSpline3_red class. Decided to go with a clearer name of ResponseFunction for the python binding
93  // and make the interface a bit more python-y. Additionally remove passing root stuff so we don't need to deal
94  // with root python binding and can just pass it native python objects.
95  py::class_<TSpline3_red, TResponseFunction_red, std::unique_ptr<TSpline3_red, py::nodelete>>(m_splines, "ResponseFunction")
96  .def(
97  // define a more python friendly constructor that massages the inputs and passes them
98  // through to the c++ constructor
99  py::init
100  (
101  // Just take in some vectors, then build a TSpline3 and pass this to the constructor
102  [](std::vector<double> &xVals, std::vector<double> &yVals, SplineInterpolation interpType)
103  {
104  if ( xVals.size() != yVals.size() )
105  {
106  throw MaCh3Exception(__FILE__, __LINE__, "Different number of x values and y values!");
107  }
108 
109  int length = int(xVals.size());
110 
111  if (length == 1)
112  {
113  M3::float_t xKnot = M3::float_t(xVals[0]);
114  M3::float_t yKnot = M3::float_t(yVals[0]);
115 
116  std::vector<M3::float_t *> pars;
117  pars.resize(3);
118  pars[0] = new M3::float_t(0.0);
119  pars[1] = new M3::float_t(0.0);
120  pars[2] = new M3::float_t(0.0);
121  delete pars[0];
122  delete pars[1];
123  delete pars[2];
124 
125  return new TSpline3_red(&xKnot, &yKnot, 1, pars.data());
126  }
127 
128  TSpline3 *splineTmp = new TSpline3( "spline_tmp", xVals.data(), yVals.data(), length );
129  return new TSpline3_red(splineTmp, interpType);
130  }
131  )
132  )
133 
134  .def(
135  "find_segment",
137  "Find the segment that a particular *value* lies in. \n"
138  ":param value: The value to test",
139  py::arg("value")
140  )
141 
142  .def(
143  "evaluate",
145  "Evaluate the response function at a particular *value*. \n"
146  ":param value: The value to evaluate at.",
147  py::arg("value")
148  )
149  ; // End of binding for ResponseFunction
150 
151  py::class_<SMonolith, SplineBase>(m_splines, "EventSplineMonolith")
152  .def(
153  py::init(
154  [](std::vector<std::vector<TResponseFunction_red*>> &responseFns, const bool saveFlatTree)
155  {
156  std::vector<RespFuncType> respFnTypes;
157  for(uint i = 0; i < responseFns[0].size(); i++)
158  {
159  // ** WARNING **
160  // Right now I'm only pushing back TSpline3_reds as that's all that's supported right now
161  // In the future there might be more
162  // I think what would be best to do would be to store the interpolation type somehow in the ResponseFunction objects
163  // then just read them here and pass through to the constructor
164  respFnTypes.push_back(RespFuncType::kTSpline3_red);
165  }
166  return new SMonolith(responseFns, respFnTypes, saveFlatTree);
167  }
168  ),
169  "Create an EventSplineMonolith \n"
170  ":param master_splines: These are the 'knot' values to make splines from. This should be an P x E 2D list where P is the number of parameters and E is the number of events. \n"
171  ":param save_flat_tree: Whether we want to save monolith into speedy flat tree",
172  py::arg("master_splines"),
173  py::arg("save_flat_tree") = false
174  )
175 
176  .def(
177  py::init<std::string>(),
178  "Constructor where you pass path to preprocessed root FileName which is generated by creating an EventSplineMonolith with the `save_flat_tree` flag set to True. \n"
179  ":param file_name: The name of the file to read from.",
180  py::arg("file_name")
181  )
182 
183  .def(
184  "evaluate",
186  "Evaluate the splines at their current values."
187  )
188 
189  .def(
190  "sync_mem_transfer",
192  "This is important when running on GPU. After calculations are done on GPU we copy memory to CPU. This operation is asynchronous meaning while memory is being copied some operations are being carried. Memory must be copied before actual reweight. This function make sure all has been copied."
193  )
194 
195  .def(
196  "get_event_weight",
198  py::return_value_policy::reference,
199  "Get the weight of a particular event. \n"
200  ":param event: The index of the event whose weight you would like.",
201  py::arg("event")
202  )
203 
204  .def(
205  "set_param_value_array",
206  // Wrap up the setSplinePointers method so that we can take in a numpy array and get
207  // pointers to it's sweet sweet data and use those pointers in the splineMonolith
208  [](SMonolith &self, py::array_t<double, py::array::c_style> &array)
209  {
210  py::buffer_info bufInfo = array.request();
211 
212  if ( bufInfo.ndim != 1)
213  {
214  throw MaCh3Exception(__FILE__, __LINE__, "Number of dimensions in parameter array must be one!");
215  }
216 
217  if ( bufInfo.shape[0] != self.GetNParams() )
218  {
219  throw MaCh3Exception(__FILE__, __LINE__, "Number of entries in parameter array must equal the number of parameters!");
220  }
221 
222  std::vector<const double *> paramVec;
223  paramVec.resize(self.GetNParams());
224 
225  for( int idx = 0; idx < self.GetNParams(); idx++ )
226  {
227  // booooo pointer arithmetic
228  paramVec[idx] = array.data() + idx;
229  }
230 
231  self.setSplinePointers(paramVec);
232  },
233  "Set the array that the monolith should use to read parameter values from. \n"
234  "Usage of this might vary a bit from what you're used to in python. \n"
235  "Rather than just setting the values here, what you're really doing is setting pointers in the underlying c++ code. \n"
236  "What that means is that you pass an array to this function like:: \n"
237  "\n event_spline_monolith_instance.set_param_value_array(array) \n\n"
238  "Then when you set values in that array as normal, they will also be updated inside of the event_spline_monolith_instance.",
239  py::arg("array")
240 
241  )
242 
243  .doc() = "This 'monolith' deals with event by event weighting using splines."
244 
245  ; // End of binding for EventSplineMonolith
246 }
SplineInterpolation
Make an enum of the spline interpolation type.
@ kTSpline3
Default TSpline3 interpolation.
@ kMonotonic
EM: DOES NOT make the entire spline monotonic, only the segments.
@ kSplineInterpolations
This only enumerates.
@ kLinear
Linear interpolation between knots.
@ kLinearFunc
Linear interpolation using TF1, not a spline.
@ kAkima
EM: Akima spline is allowed to be discontinuous in 2nd derivative and coefficients in any segment.
@ kTSpline3_red
Uses TSpline3_red for interpolation.
Contains structures and helper functions for handling spline representations of systematic parameters...
Custom exception class used throughout MaCh3.
EW: As SplineBase is an abstract base class we have to do some gymnastics to get it into py...
Definition: splines.cpp:19
std::string GetName() const override
Get class name.
Definition: splines.cpp:34
void FindSplineSegment()
Definition: splines.cpp:43
void CalcSplineWeights() override
CPU-based code which evaluates the weight for each spline.
Definition: splines.cpp:52
void ModifyWeights() override
Calc total event weight.
Definition: splines.cpp:61
void Evaluate() override
CW: This Eval should be used when using two separate x,{y,a,b,c,d} arrays to store the weights; proba...
Definition: splines.cpp:25
Event-by-event class calculating response for spline parameters. It is possible to use GPU acceleratio...
void SynchroniseMemTransfer() const override
KS: After calculations are done on GPU we copy memory to CPU. This operation is asynchronous meaning ...
void Evaluate() override
CW: This Eval should be used when using two separate x,{y,a,b,c,d} arrays to store the weights; proba...
const float * retPointer(const int event) const
KS: Get pointer to total weight to make fit faster wrooom!
Base class for calculating weight from spline.
Definition: SplineBase.h:25
SplineBase()
Constructor.
Definition: SplineBase.cpp:7
CW: Reduced TSpline3 class.
double Eval(double var) override
CW: Evaluate the weight from a variation.
int FindX(double x)
Find the segment relevant to this variation in x.
double float_t
Definition: Core.h:37
void initSplines(py::module &m)
Definition: splines.cpp:72