# Source code for greens_basis_transform
# Taken directly from the repository associated with
# Agol, Luger, and Foreman-Mackey 2020 (doi:10.3847/1538-3881/ab4fee)
# Many thanks to the authors for making this code available!
import jax
# Enable 64-bit precision in JAX up front: generate_change_of_basis_matrix
# below requests dtype=jnp.float64, which JAX only honours when x64 is enabled.
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from sympy import binomial, symbols, zeros
# Define our symbols
# `z` is the polynomial variable the helpers in this file expand in.
# NOTE(review): the module-level `n` does not appear to be used by the code
# visible here -- every `n` in the helpers is a parameter or loop variable.
z, n = symbols("z n")
def _ptilde(n, z):
"""Return the n^th term in the polynomial basis."""
return z**n
def _Coefficient(expression, term):
    """Extract the constant factor that multiplies `term` in `expression`.

    The raw result of ``expression.coeff(term)`` is only meaningful here when
    it is a constant. Any part of it that still depends on the module-level
    symbol ``z`` does not belong to the term of interest, so ``z`` is replaced
    by 0 to drop it.
    """
    return expression.coeff(term).subs(z, 0)
def _A1(N):
    """Build the change of basis matrix A1.

    The result is an ``(N + 1) x (N + 1)`` sympy matrix whose ``(row, col)``
    entry is ``(-1) ** (row + 1) * binomial(col, row)``.
    """
    size = N + 1
    matrix = zeros(size, size)
    for row in range(size):
        # The sign depends only on the row, so compute it once per row.
        sign = (-1) ** (row + 1)
        for col in range(size):
            matrix[row, col] = sign * binomial(col, row)
    return matrix
def _gtilde(n, z):
"""Return the n^th term in the Green's basis."""
if n == 0:
return 1 + 0 * z
elif n == 1:
return z
else:
return (n + 2) * z**n - n * z ** (n - 2)
def _p_G(n, N):
    """Expand the n-th Green's basis polynomial in the polynomial basis.

    Returns a list of ``N + 1`` coefficients; entry ``k`` is the coefficient
    of ``z**k`` in the Green's polynomial of index ``n``.
    """
    green_poly = _gtilde(n, z)
    # The constant term is simply the polynomial evaluated at z = 0.
    constant = green_poly.subs(z, 0)
    higher_orders = [
        _Coefficient(green_poly, _ptilde(order, z)) for order in range(1, N + 1)
    ]
    return [constant, *higher_orders]
def _A2(N):
    """Build the change of basis matrix A2.

    The columns of the **inverse** of A2 are the coefficient lists returned by
    `_p_G`, so that inverse is assembled first and then inverted.
    """
    size = N + 1
    inverse_of_a2 = zeros(size, size)
    for order in range(size):
        # Single-index assignment of a sequence: sympy writes it as a column
        # starting at flat position `order`, i.e. column `order` of the matrix.
        inverse_of_a2[order] = _p_G(order, N)
    return inverse_of_a2.inv()
def _A(N):
    """Compose the full change of basis matrix as the product ``A2 * A1``."""
    second_stage = _A2(N)
    first_stage = _A1(N)
    return second_stage * first_stage
# [docs]
def generate_change_of_basis_matrix(N):
    """Build the matrix that maps limb darkening u coefficients to the Green's basis.

    The matrix only needs to be generated once per system; the result is then
    reused throughout the light curve calculation. This implements Eq. 17 of
    `Agol, Luger, and Foreman-Mackey 2020
    <https://ui.adsabs.harvard.edu/abs/2020AJ....159..123A/abstract>`_.

    Args:
        N (int): The order of the polynomial limb darkening law.

    Returns:
        Array:
            The change of basis matrix as a float64 JAX array. When solving for
            the blocked flux, it is multiplied by the u limb darkening
            coefficients to convert them to the Green's basis.
    """
    # The symbolic matrix is built with sympy and converted to a numeric array here.
    return jnp.array(_A(N), dtype=jnp.float64)
# Forked from jaxoplanet.core.limb_dark.py - this is another way to do it
# that avoids needing to store the change of basis matrix
# @jax.jit
# def greens_basis_transform(u: Array) -> Array:
# dtype = jnp.dtype(u)
# u = jnp.concatenate((-jnp.ones(1, dtype=dtype), u))
# size = len(u)
# i = np.arange(size)
# arg = binom(i[None, :], i[:, None]) @ u
# p = (-1) ** (i + 1) * arg
# g = [jnp.zeros((), dtype=dtype) for _ in range(size + 2)]
# for n in range(size - 1, 1, -1):
# g[n] = p[n] / (n + 2) + g[n + 2]
# g[1] = p[1] + 3 * g[3]
# g[0] = p[0] + 2 * g[2]
# return jnp.stack(g[:-2])