Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
208 changes: 156 additions & 52 deletions src/endf/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# SPDX-License-Identifier: MIT

from collections.abc import Iterable
from math import exp, log

import numpy as np

Expand Down Expand Up @@ -95,27 +94,9 @@ def __call__(self, x):
xi1 = self.x[idx[contained] + 1] # high edge of corresponding bins
yi = self.y[idx[contained]]
yi1 = self.y[idx[contained] + 1]

if self.interpolation[k] == 1:
# Histogram
y[contained] = yi

elif self.interpolation[k] == 2:
# Linear-linear
y[contained] = yi + (xk - xi)/(xi1 - xi)*(yi1 - yi)

elif self.interpolation[k] == 3:
# Linear-log
y[contained] = yi + np.log(xk/xi)/np.log(xi1/xi)*(yi1 - yi)

elif self.interpolation[k] == 4:
# Log-linear
y[contained] = yi*np.exp((xk - xi)/(xi1 - xi)*np.log(yi1/yi))

elif self.interpolation[k] == 5:
# Log-log
y[contained] = (yi*np.exp(np.log(xk/xi)/np.log(xi1/xi)
*np.log(yi1/yi)))

p = self.interpolation[k]
y[contained] = self._interpolate(p, xk, xi, yi, xi1, yi1)

# In some cases, x values might be outside the tabulated region due only
# to precision, so we check if they're close and set them equal if so.
Expand Down Expand Up @@ -143,6 +124,10 @@ def _interpolate_scalar(self, x):
yi = self._y[idx]
yi1 = self._y[idx + 1]

return self._interpolate(p, x, xi, yi, xi1, yi1)

@staticmethod
def _interpolate(p, x, xi, yi, xi1, yi1):
if p == 1:
# Histogram
return yi
Expand All @@ -153,15 +138,18 @@ def _interpolate_scalar(self, x):

elif p == 3:
# Linear-log
return yi + log(x/xi)/log(xi1/xi)*(yi1 - yi)
return yi + np.log(x/xi)/np.log(xi1/xi)*(yi1 - yi)

elif p == 4:
# Log-linear
return yi*exp((x - xi)/(xi1 - xi)*log(yi1/yi))
return yi*np.exp((x - xi)/(xi1 - xi)*np.log(yi1/yi))

elif p == 5:
# Log-log
return yi*exp(log(x/xi)/log(xi1/xi)*log(yi1/yi))
return yi*np.exp(np.log(x/xi)/np.log(xi1/xi)*np.log(yi1/yi))

else:
raise ValueError(f"Unknown interpolation rule {p}")

def __len__(self):
return len(self.x)
Expand Down Expand Up @@ -230,38 +218,154 @@ def integral(self):
x1 = self.x[i_low + 1:i_high + 1]
y0 = self.y[i_low:i_high]
y1 = self.y[i_low + 1:i_high + 1]

if self.interpolation[k] == 1:
# Histogram
partial_sum[i_low:i_high] = y0*(x1 - x0)

elif self.interpolation[k] == 2:
# Linear-linear
m = (y1 - y0)/(x1 - x0)
partial_sum[i_low:i_high] = (y0 - m*x0)*(x1 - x0) + \
m*(x1**2 - x0**2)/2

elif self.interpolation[k] == 3:
# Linear-log
logx = np.log(x1/x0)
m = (y1 - y0)/logx
partial_sum[i_low:i_high] = y0 + m*(x1*(logx - 1) + x0)

elif self.interpolation[k] == 4:
# Log-linear
m = np.log(y1/y0)/(x1 - x0)
partial_sum[i_low:i_high] = y0/m*(np.exp(m*(x1 - x0)) - 1)

elif self.interpolation[k] == 5:
# Log-log
m = np.log(y1/y0)/np.log(x1/x0)
partial_sum[i_low:i_high] = y0/((m + 1)*x0**m)*(
x1**(m + 1) - x0**(m + 1))

p = self.interpolation[k]
partial_sum[i_low:i_high] = self._integrate(p, x0, y0, x1, y1)

i_low = i_high

return np.concatenate(([0.], np.cumsum(partial_sum)))

def integrate(self, a: float, b: float, clip_bounds: bool = True) -> float:
"""Performs a definite integral over a portion of the function.

Performs the definite integral over the range [a,b] for the tabulated
function. If clip_bounds is True (default), then the integration bounds
will be truncated to be within the tabulated domain. This is equivalent
to the function being zero outside the tabulation. If clip_bounds is
False, then a ValueError is raised if a or b are outside the domain.

Parameters
----------
a : float
Lower bound of integration.
b : float
Upper bound of integration.
clip_bounds : bool, default True
If True, the integration bounds are clipped to cover the tabulated
domain of the function. If False, a ValueError will be raised if
either integration bound is outside the tabulated domain.

Returns
-------
float
Value of the definite integral.

Raises
------
ValueError
If clip_bounds is False and either of the integration bounds a or b
are outside the domain of the tabulated function.
"""
# If the integration range doesn't go from low to high, flip the order
flipped = False
if a > b:
flipped = True
tmp = a
a = b
b = tmp

# Next we check that the integration range is valid. If the bounds go
# off the grid, we clip them if clib_bounds is True. Otherwise, we
# raise an exception.
if not clip_bounds:
if a < self._x[0] or self._x[-1] < b:
raise ValueError("Integration bounds are outside of the "
"tabulated function domain.")
else:
# Clip the bounds if necessary
if a < self._x[0]:
a = self._x[0]
if self._x[-1] < b:
b = self._x[-1]

# Check for this special case
if a == b:
return 0.

# This function finds the interpolation rule for a given index
def get_interpolation(indx: int) -> int:
# Loop over interpolation regions
for b, p in zip(self.breakpoints, self.interpolation):
if indx < b - 1:
return p
# Should never get here
return self.interpolation[-1]

# Now we can start to perform the real integral
integral = 0.
x_lower_bound = a
x_upper_bound = b

# Get the first index for interpolation
idx = np.searchsorted(self._x, x_lower_bound, side='right') - 1

while idx < self._x.size-1:
# Get the interpolation rule
interp = get_interpolation(idx)

# Get tabulated values
xi = self._x[idx] # low edge of the corresponding bin
xi1 = self._x[idx + 1] # high edge of the corresponding bin
yi = self._y[idx]
yi1 = self._y[idx + 1]

# If we are at one of the end points, perform the necessary interpolation
if xi < x_lower_bound:
yi = self._interpolate(interp, x_lower_bound, xi, yi, xi1, yi1)
xi = x_lower_bound

if x_upper_bound < xi1:
yi1 = self._interpolate(interp, x_upper_bound, xi, yi, xi1, yi1)
xi1 = x_upper_bound

# Add check to ensure the loop will stop
if x_upper_bound == xi1:
idx = self._x.size

# Contribute to the integral
integral += self._integrate(interp, xi, yi, xi1, yi1)

# Prepare for next iteration
idx += 1
x_lower_bound = xi1

# If we had to flip integration bounds, multiply by -1
if flipped:
integral = -integral

return integral

@staticmethod
def _integrate(p, x0, y0, x1, y1):
if p == 1:
# Histogram
return y0*(x1 - x0)

elif p == 2:
# Linear-linear
m = (y1 - y0)/(x1 - x0)
return (y0 - m*x0)*(x1 - x0) + 0.5*m*(x1**2 - x0**2)

elif p == 3:
# Linear-log
logx = np.log(x1/x0)
m = (y1 - y0)/logx
return y0 + m*(x1*(logx - 1) + x0)

elif p == 4:
# Log-linear
m = np.log(y1/y0)/(x1 - x0)
return y0/m*(np.exp(m*(x1 - x0)) - 1)

elif p == 5:
# Log-log
m = np.log(y1/y0)/np.log(x1/x0)
return y0/((m + 1)*x0**m)*(x1**(m + 1) - x0**(m + 1))

else:
raise ValueError(f"Unknown interpolation rule {p}")

@classmethod
def from_ace(cls, ace, idx=0, convert_units=True):
"""Create a Tabulated1D object from an ACE table.
Expand Down