diff --git a/src/endf/function.py b/src/endf/function.py index 1fa05db..15e179f 100644 --- a/src/endf/function.py +++ b/src/endf/function.py @@ -2,7 +2,6 @@ # SPDX-License-Identifier: MIT from collections.abc import Iterable -from math import exp, log import numpy as np @@ -95,27 +94,9 @@ def __call__(self, x): xi1 = self.x[idx[contained] + 1] # high edge of corresponding bins yi = self.y[idx[contained]] yi1 = self.y[idx[contained] + 1] - - if self.interpolation[k] == 1: - # Histogram - y[contained] = yi - - elif self.interpolation[k] == 2: - # Linear-linear - y[contained] = yi + (xk - xi)/(xi1 - xi)*(yi1 - yi) - - elif self.interpolation[k] == 3: - # Linear-log - y[contained] = yi + np.log(xk/xi)/np.log(xi1/xi)*(yi1 - yi) - - elif self.interpolation[k] == 4: - # Log-linear - y[contained] = yi*np.exp((xk - xi)/(xi1 - xi)*np.log(yi1/yi)) - - elif self.interpolation[k] == 5: - # Log-log - y[contained] = (yi*np.exp(np.log(xk/xi)/np.log(xi1/xi) - *np.log(yi1/yi))) + + p = self.interpolation[k] + y[contained] = self._interpolate(p, xk, xi, yi, xi1, yi1) # In some cases, x values might be outside the tabulated region due only # to precision, so we check if they're close and set them equal if so. @@ -143,6 +124,10 @@ def _interpolate_scalar(self, x): yi = self._y[idx] yi1 = self._y[idx + 1] + return self._interpolate(p, x, xi, yi, xi1, yi1) + + @staticmethod + def _interpolate(p, x, xi, yi, xi1, yi1): if p == 1: # Histogram return yi @@ -153,15 +138,18 @@ def _interpolate_scalar(self, x): elif p == 3: # Linear-log - return yi + log(x/xi)/log(xi1/xi)*(yi1 - yi) + return yi + np.log(x/xi)/np.log(xi1/xi)*(yi1 - yi) elif p == 4: # Log-linear - return yi*exp((x - xi)/(xi1 - xi)*log(yi1/yi)) + return yi*np.exp((x - xi)/(xi1 - xi)*np.log(yi1/yi)) elif p == 5: # Log-log - return yi*exp(log(x/xi)/log(xi1/xi)*log(yi1/yi)) + return yi*np.exp(np.log(x/xi)/np.log(xi1/xi)*np.log(yi1/yi)) + + else: + raise ValueError(f"Unknown interpolation rule {p}") def __len__(self): return len(self.x) @@ -230,38 +218,154 @@ def integral(self): x1 = self.x[i_low + 1:i_high + 1] y0 = self.y[i_low:i_high] y1 = self.y[i_low + 1:i_high + 1] - - if self.interpolation[k] == 1: - # Histogram - partial_sum[i_low:i_high] = y0*(x1 - x0) - - elif self.interpolation[k] == 2: - # Linear-linear - m = (y1 - y0)/(x1 - x0) - partial_sum[i_low:i_high] = (y0 - m*x0)*(x1 - x0) + \ - m*(x1**2 - x0**2)/2 - - elif self.interpolation[k] == 3: - # Linear-log - logx = np.log(x1/x0) - m = (y1 - y0)/logx - partial_sum[i_low:i_high] = y0 + m*(x1*(logx - 1) + x0) - - elif self.interpolation[k] == 4: - # Log-linear - m = np.log(y1/y0)/(x1 - x0) - partial_sum[i_low:i_high] = y0/m*(np.exp(m*(x1 - x0)) - 1) - - elif self.interpolation[k] == 5: - # Log-log - m = np.log(y1/y0)/np.log(x1/x0) - partial_sum[i_low:i_high] = y0/((m + 1)*x0**m)*( - x1**(m + 1) - x0**(m + 1)) + + p = self.interpolation[k] + partial_sum[i_low:i_high] = self._integrate(p, x0, y0, x1, y1) i_low = i_high return np.concatenate(([0.], np.cumsum(partial_sum))) + def integrate(self, a: float, b: float, clip_bounds: bool = True) -> float: + """Performs a definite integral over a portion of the function. + + Performs the definite integral over the range [a,b] for the tabulated + function. If clip_bounds is True (default), then the integration bounds + will be truncated to be within the tabulated domain. This is equivalent + to the function being zero outside the tabulation. If clip_bounds is + False, then a ValueError is raised if a or b are outside the domain. + + Parameters + ---------- + a : float + Lower bound of integration. + b : float + Upper bound of integration. + clip_bounds : bool, default True + If True, the integration bounds are clipped to cover the tabulated + domain of the function. If False, a ValueError will be raised if + either integration bound is outside the tabulated domain. + + Returns + ------- + float + Value of the definite integral. + + Raises + ------ + ValueError + If clip_bounds is False and either of the integration bounds a or b + are outside the domain of the tabulated function. + """ + # If the integration range doesn't go from low to high, flip the order + flipped = False + if a > b: + flipped = True + tmp = a + a = b + b = tmp + + # Next we check that the integration range is valid. If the bounds go + # off the grid, we clip them if clib_bounds is True. Otherwise, we + # raise an exception. + if not clip_bounds: + if a < self._x[0] or self._x[-1] < b: + raise ValueError("Integration bounds are outside of the " + "tabulated function domain.") + else: + # Clip the bounds if necessary + if a < self._x[0]: + a = self._x[0] + if self._x[-1] < b: + b = self._x[-1] + + # Check for this special case + if a == b: + return 0. + + # This function finds the interpolation rule for a given index + def get_interpolation(indx: int) -> int: + # Loop over interpolation regions + for b, p in zip(self.breakpoints, self.interpolation): + if indx < b - 1: + return p + # Should never get here + return self.interpolation[-1] + + # Now we can start to perform the real integral + integral = 0. + x_lower_bound = a + x_upper_bound = b + + # Get the first index for interpolation + idx = np.searchsorted(self._x, x_lower_bound, side='right') - 1 + + while idx < self._x.size-1: + # Get the interpolation rule + interp = get_interpolation(idx) + + # Get tabulated values + xi = self._x[idx] # low edge of the corresponding bin + xi1 = self._x[idx + 1] # high edge of the corresponding bin + yi = self._y[idx] + yi1 = self._y[idx + 1] + + # If we are at one of the end points, perform the necessary interpolation + if xi < x_lower_bound: + yi = self._interpolate(interp, x_lower_bound, xi, yi, xi1, yi1) + xi = x_lower_bound + + if x_upper_bound < xi1: + yi1 = self._interpolate(interp, x_upper_bound, xi, yi, xi1, yi1) + xi1 = x_upper_bound + + # Add check to ensure the loop will stop + if x_upper_bound == xi1: + idx = self._x.size + + # Contribute to the integral + integral += self._integrate(interp, xi, yi, xi1, yi1) + + # Prepare for next iteration + idx += 1 + x_lower_bound = xi1 + + # If we had to flip integration bounds, multiply by -1 + if flipped: + integral = -integral + + return integral + + @staticmethod + def _integrate(p, x0, y0, x1, y1): + if p == 1: + # Histogram + return y0*(x1 - x0) + + elif p == 2: + # Linear-linear + m = (y1 - y0)/(x1 - x0) + return (y0 - m*x0)*(x1 - x0) + 0.5*m*(x1**2 - x0**2) + + elif p == 3: + # Linear-log + logx = np.log(x1/x0) + m = (y1 - y0)/logx + return y0 + m*(x1*(logx - 1) + x0) + + elif p == 4: + # Log-linear + m = np.log(y1/y0)/(x1 - x0) + return y0/m*(np.exp(m*(x1 - x0)) - 1) + + elif p == 5: + # Log-log + m = np.log(y1/y0)/np.log(x1/x0) + return y0/((m + 1)*x0**m)*(x1**(m + 1) - x0**(m + 1)) + + else: + raise ValueError(f"Unknown interpolation rule {p}") + @classmethod def from_ace(cls, ace, idx=0, convert_units=True): """Create a Tabulated1D object from an ACE table.