# Source code for polynomial_limb_darkened_transit

import jax

jax.config.update("jax_enable_x64", True)
from functools import partial

import jax.numpy as jnp
from quadax import quadgk

from squishyplanet.engine.kepler import kepler, skypos, t0_to_t_peri
from squishyplanet.engine.parametric_ellipse import (
    cartesian_intersection_to_parametric_angle,
    poly_to_parametric,
)
from squishyplanet.engine.planet_2d import planet_2d_coeffs
from squishyplanet.engine.planet_3d import planet_3d_coeffs

# Absolute and relative error tolerances passed to the quadgk integrations
# performed in this module.
epsabs = 1e-12
epsrel = 1e-12


def _t4(rho_xx, rho_xy, rho_x0, rho_yy, rho_y0, rho_00):
    return -1 + rho_00 - rho_x0 + rho_xx


def _t3(rho_xx, rho_xy, rho_x0, rho_yy, rho_y0, rho_00):
    return -2 * rho_xy + 2 * rho_y0


def _t2(rho_xx, rho_xy, rho_x0, rho_yy, rho_y0, rho_00):
    return -2 + 2 * rho_00 - 2 * rho_xx + 4 * rho_yy


def _t1(rho_xx, rho_xy, rho_x0, rho_yy, rho_y0, rho_00):
    return 2 * rho_xy + 2 * rho_y0


def _t0(rho_xx, rho_xy, rho_x0, rho_yy, rho_y0, rho_00):
    return -1 + rho_00 + rho_x0 + rho_xx


@jax.jit
def _single_intersection_points(
    rho_xx, rho_xy, rho_x0, rho_yy, rho_y0, rho_00, **kwargs
):
    """Locate where one planet outline crosses the unit circle (the stellar limb).

    The outline's implicit coefficients are turned into a quartic in t via the
    substitution x = (1 - t**2) / (1 + t**2), y = 2 t / (1 + t**2) (see the
    ``_t*`` helpers). Each real root t is mapped back to a Cartesian point on
    the unit circle.

    NOTE: this substitution cannot represent the point (x, y) = (-1, 0)
    itself (it corresponds to t -> infinity).

    Args:
        rho_xx, rho_xy, rho_x0, rho_yy, rho_y0, rho_00: Implicit coefficients
            of the planet's projected outline.
        **kwargs: Accepted and ignored, so that a coefficient dictionary with
            additional entries can be unpacked directly.

    Returns:
        tuple: ``(xs, ys)``, one entry per root of the quartic. Entries whose
        root is not purely real are set to the sentinel value 999.
    """
    coeffs = (rho_xx, rho_xy, rho_x0, rho_yy, rho_y0, rho_00)
    quartic = jnp.array(
        [_t4(*coeffs), _t3(*coeffs), _t2(*coeffs), _t1(*coeffs), _t0(*coeffs)]
    )
    # jit needs a fixed output length, so leading zeros must not be stripped
    roots = jnp.roots(quartic, strip_zeros=False)

    sentinel = 999
    t = jnp.where(jnp.imag(roots) == 0, jnp.real(roots), sentinel)
    is_root = t != sentinel
    t_sq = t**2
    # where the root is absent, t already holds the sentinel, so pass it through
    xs = jnp.where(is_root, (1 - t_sq) / (1 + t_sq), t)
    ys = jnp.where(is_root, 2 * t / (1 + t_sq), t)
    return xs, ys


# @jax.jit
# def multiple_intersection_points(rho_xx, rho_xy, rho_x0, rho_yy, rho_y0, rho_00, **kwargs):
#     t4 = _t4(rho_xx, rho_xy, rho_x0, rho_yy, rho_y0, rho_00)
#     t3 = _t3(rho_xx, rho_xy, rho_x0, rho_yy, rho_y0, rho_00)
#     t2 = _t2(rho_xx, rho_xy, rho_x0, rho_yy, rho_y0, rho_00)
#     t1 = _t1(rho_xx, rho_xy, rho_x0, rho_yy, rho_y0, rho_00)
#     t0 = _t0(rho_xx, rho_xy, rho_x0, rho_yy, rho_y0, rho_00)

#     polys = jnp.array([t4, t3, t2, t1, t0]).T

#     roots = jax.vmap(lambda x: jnp.roots(x, strip_zeros=False))(polys)

#     ts = jnp.where(jnp.imag(roots) == 0, jnp.real(roots), 999)
#     xs = jnp.where(ts != 999, (1-ts**2)/(1+ts**2), ts)
#     ys = jnp.where(ts != 999, 2*ts/(1+ts**2), ts)
#     return xs, ys


@jax.jit
def parameterize_2d_helper(projected_r, projected_f, projected_theta, xc, yc):
    """Convert from the alternative sky-projected parameterization to the same
    format used by the 3D parameterization.

    A good chunk of the code assumes that the planet's center is determined by
    the orbital elements and that its outline is derived from an equatorial
    radius ``r``, a z-flattening ``f1``, a y-flattening ``f2``, and two
    body-centered rotations ``obliq`` and ``prec``. These are useful to have
    when working with phase curves that are sensitive to the actual 3D
    orientation of the planet. When dealing with transits only, however, this
    parameterization is overkill and allows a bunch of degeneracies. So, if
    only doing transits, it is more convenient to parameterize the planet
    directly by its projected ellipse: one semi-axis length, a flattening, and
    a rotation angle. This function takes in those parameters and returns the
    same dictionaries you'd get if you fed a full 3D parameterization into
    :func:`planet_2d.planet_2d_coeffs`.

    The ellipse has a semi-axis of length ``projected_r`` pointing at angle
    ``projected_theta`` from the x axis, and a perpendicular semi-axis of
    length ``projected_r * (1 - projected_f)``. Its implicit form is

    .. math::

        \\rho_{xx} x^2 + \\rho_{xy} x y + \\rho_{x0} x + \\rho_{yy} y^2
        + \\rho_{y0} y + \\rho_{00} = 1

    Args:
        projected_r (float): The length of the projected ellipse's semi-axis
            that is oriented along ``projected_theta``.
        projected_f (float): The flattening of the projected ellipse.
        projected_theta (float): The angle of the projected ellipse.
        xc (float or Array): The sky-plane x coordinate(s) of the planet's
            center.
        yc (float or Array): The sky-plane y coordinate(s) of the planet's
            center.

    Returns:
        tuple: A tuple of two dictionaries. The first contains the
        coefficients ``rho_xx``, ``rho_xy``, ``rho_x0``, ``rho_yy``,
        ``rho_y0`` and ``rho_00`` of the implicit equation above. The second
        contains coefficients that describe the parametric form of that same
        ellipse, as returned by ``poly_to_parametric``.
    """
    # length of the semi-axis perpendicular to the one of length projected_r
    projected_r2 = projected_r * (1 - projected_f)
    cos_t = jnp.cos(projected_theta)
    sin_t = jnp.sin(projected_theta)

    # Expansion of
    #   ((x-xc) cos + (y-yc) sin)^2 / projected_r^2
    #   + (-(x-xc) sin + (y-yc) cos)^2 / projected_r2^2 = 1.
    # NOTE: algebraically equivalent forms that group terms by
    # 1 / projected_r**2 and 1 / projected_r2**2 (with cached squares) were
    # drafted during the Fortran implementation. Their numerical implications
    # have not been checked yet, so this term-by-term form is kept.
    two = {
        "rho_xx": cos_t**2 / projected_r**2 + sin_t**2 / projected_r2**2,
        "rho_xy": (2 * cos_t * sin_t) / projected_r**2
        - (2 * cos_t * sin_t) / projected_r2**2,
        "rho_x0": (-2 * cos_t**2 * xc) / projected_r**2
        - (2 * cos_t * yc * sin_t) / projected_r**2
        + (2 * cos_t * yc * sin_t) / projected_r2**2
        - (2 * xc * sin_t**2) / projected_r2**2,
        "rho_yy": cos_t**2 / projected_r2**2 + sin_t**2 / projected_r**2,
        "rho_y0": (-2 * cos_t**2 * yc) / projected_r2**2
        - (2 * cos_t * xc * sin_t) / projected_r**2
        + (2 * cos_t * xc * sin_t) / projected_r2**2
        - (2 * yc * sin_t**2) / projected_r**2,
        "rho_00": (cos_t**2 * xc**2) / projected_r**2
        + (cos_t**2 * yc**2) / projected_r2**2
        + (2 * cos_t * xc * yc * sin_t) / projected_r**2
        - (2 * cos_t * xc * yc * sin_t) / projected_r2**2
        + (yc**2 * sin_t**2) / projected_r**2
        + (xc**2 * sin_t**2) / projected_r2**2,
    }

    para = poly_to_parametric(**two)
    return two, para
@jax.jit
def planet_solution_vec(a, b, g_coeffs, c_x1, c_x2, c_x3, c_y1, c_y2, c_y3):
    """Compute the "solution vector" for a 1D path across the star that lies on
    the outline of the planet.

    This computes Eq. 21 of `Agol, Luger, and Foreman-Mackey 2020
    <https://ui.adsabs.harvard.edu/abs/2020AJ....159..123A/abstract>`_. But,
    instead of doing it analytically, this uses the ``quadax`` package to
    numerically solve the required integrals.

    For terms s_2 and higher, this is straightforward to do based on the
    equations in the paper: we simply parameterize the outline of the planet
    by some angle :math:`\\alpha`, then numerically integrate the dot product
    of Eq. 62 with that parameterization between the two endpoints of the
    path. For the first two lower-order terms, however, Agol et al. do not
    provide an equivalent of Eq. 62 and instead provide only the analytic
    solutions. We therefore use the following as the equivalents of Eq. 62
    for these terms:

    .. math::

        G_0 = \\{0, x\\}

    .. math::

        G_1 = \\left\\{0, \\frac{1}{2} \\left(x \\sqrt{-x^2-y^2+1}
        -\\left(y^2-1\\right) \\tan ^{-1}\\left(\\frac{x}{\\sqrt{-x^2-y^2+1}}
        \\right)\\right)+\\frac{\\pi }{12} \\right\\}

    These expressions were derived by solving the required PDE in Eq. 14 with
    the boundary conditions from Eq. 27.

    Finally, the C coefficients here describe the parametric form of the
    planet's outline as seen by the observer, and they satisfy:

    .. math::

        x = c_{x1} \\cos(\\alpha) + c_{x2} \\sin(\\alpha) + c_{x3}

        y = c_{y1} \\cos(\\alpha) + c_{y2} \\sin(\\alpha) + c_{y3}

    for some angle :math:`\\alpha \\in [0, 2\\pi)`.

    Args:
        a (float): The starting parameter for the path along the planet's
            outline, :math:`\\alpha_0`.
        b (float): The ending parameter for the path along the planet's
            outline, :math:`\\alpha_1`.
        g_coeffs (Array): The system-specific limb darkening coefficients in
            the Green's basis. Computed by multiplying the u coefficients with
            the change of basis matrix from
            :func:`greens_basis_transform.generate_change_of_basis_matrix`.
        c_x1 (float): The first coefficient of the parametric 2D outline of
            the planet.
        c_x2 (float): The second coefficient of the parametric 2D outline of
            the planet.
        c_x3 (float): The third coefficient of the parametric 2D outline of
            the planet.
        c_y1 (float): The fourth coefficient of the parametric 2D outline of
            the planet.
        c_y2 (float): The fifth coefficient of the parametric 2D outline of
            the planet.
        c_y3 (float): The sixth coefficient of the parametric 2D outline of
            the planet.

    Returns:
        Array: The solution vector for the path along the planet's outline.
        The shape will match that of the input ``g_coeffs``.
    """

    def _outline_point(s):
        # (x, y) on the planet's outline at parametric angle s
        x = jnp.cos(s) * c_x1 + jnp.sin(s) * c_x2 + c_x3
        y = jnp.cos(s) * c_y1 + jnp.sin(s) * c_y2 + c_y3
        return x, y

    def _dy_ds(s):
        # derivative of the outline's y coordinate with respect to s
        return -(jnp.sin(s) * c_y1) + jnp.cos(s) * c_y2

    def s0_integrand(s):
        # G_0 . dr = x dy
        x, _ = _outline_point(s)
        return x * _dy_ds(s)

    def s1_integrand(s):
        # G_1 . dr, with G_1 as given in the docstring
        x, y = _outline_point(s)
        z = jnp.sqrt(1 - x**2 - y**2)
        bracket = jnp.pi + 6 * x * z - 6 * jnp.arctan(x / z) * (-1 + y**2)
        return (_dy_ds(s) * bracket) / 12.0

    def sn_integrand(s):
        # Everything that does not depend on the order n is computed once per
        # quadrature node, outside the scan over orders.
        x, y = _outline_point(s)
        z_sq = 1 - x**2 - y**2
        # equals -(x dy/ds - y dx/ds) along the parametric outline
        neg_cross = (
            c_x3 * (jnp.sin(s) * c_y1 - jnp.cos(s) * c_y2)
            + c_x2 * (c_y1 + jnp.cos(s) * c_y3)
            - c_x1 * (c_y2 + jnp.sin(s) * c_y3)
        )

        def scan_func(carry, n):
            return carry, -(z_sq ** (n / 2.0) * neg_cross)

        orders = jnp.arange(g_coeffs.shape[0])[2:]
        return jax.lax.scan(scan_func, None, orders)[1]

    interval = jnp.array([a, b])
    higher_terms, _ = quadgk(sn_integrand, interval, epsabs=epsabs, epsrel=epsrel)
    s0, _ = quadgk(s0_integrand, interval, epsabs=epsabs, epsrel=epsrel)
    s1, _ = quadgk(s1_integrand, interval, epsabs=epsabs, epsrel=epsrel)

    # the s0/s1 integrals come back as scalars (e.g. when scanning over
    # individual phases), so wrap them before concatenating
    return jnp.concatenate((jnp.array([s0]), jnp.array([s1]), higher_terms))
@jax.jit
def star_solution_vec(a, b, g_coeffs, c_x1, c_x2, c_x3, c_y1, c_y2, c_y3):
    """Compute the "solution vector" for a 1D path across the star that lies on
    the edge of the star itself.

    This is equivalent to :func:`planet_solution_vec`, but instead of
    integrating over paths that lie on the planet's outline, we integrate over
    paths that lie on the edge of the star. As pointed out in the paragraph
    following Eq. 69 in `Agol, Luger, and Foreman-Mackey 2020
    <https://ui.adsabs.harvard.edu/abs/2020AJ....159..123A/abstract>`_, the
    contribution of all terms higher than :math:`G_1` will be zero in this
    case, since we have limited ourselves to :math:`z=0` by remaining on the
    star's boundary. This simplifies things somewhat, though we do still have
    to numerically integrate the dot product of the parametric form of the
    star's outline with the :math:`G_0` and :math:`G_1` terms written out in
    :func:`planet_solution_vec`. Technically we probably could use the
    analytic solutions for these terms, but so far we have not.

    The path taken is the shorter of the two arcs of the stellar limb that
    join the endpoints.

    Args:
        a (float): The starting parameter for the path along the star's
            outline, :math:`\\alpha_0`. Note: here :math:`\\alpha` is the
            angle parameterizing the path on the *planet's* outline, not the
            star's, even though the path we will integrate over is on the
            star. We convert to the relevant parameters internally.
        b (float): The ending parameter for the path along the star's
            outline, :math:`\\alpha_1`.
        g_coeffs (Array): The system-specific limb darkening coefficients in
            the Green's basis. Computed by multiplying the u coefficients with
            the change of basis matrix from
            :func:`greens_basis_transform.generate_change_of_basis_matrix`.
        c_x1 (float): The first coefficient of the parametric 2D outline of
            the planet.
        c_x2 (float): The second coefficient of the parametric 2D outline of
            the planet.
        c_x3 (float): The third coefficient of the parametric 2D outline of
            the planet.
        c_y1 (float): The fourth coefficient of the parametric 2D outline of
            the planet.
        c_y2 (float): The fifth coefficient of the parametric 2D outline of
            the planet.
        c_y3 (float): The sixth coefficient of the parametric 2D outline of
            the planet.

    Returns:
        Array: The solution vector for the path along the star's outline. The
        shape will match that of the input ``g_coeffs``; only the first two
        entries are nonzero.
    """

    def _polar_angle(alpha):
        # Polar angle, wrapped into [0, 2*pi), of the planet-outline point at
        # parameter alpha. The endpoints are expected to be the planet/limb
        # intersection points, so this is also the star's own edge angle.
        x = c_x1 * jnp.cos(alpha) + c_x2 * jnp.sin(alpha) + c_x3
        y = c_y1 * jnp.cos(alpha) + c_y2 * jnp.sin(alpha) + c_y3
        theta = jnp.arctan2(y, x)
        return jnp.where(theta < 0, theta + 2 * jnp.pi, theta)

    _theta1 = _polar_angle(a)
    _theta2 = _polar_angle(b)

    # order the endpoints so that theta1 <= theta2
    theta1 = jnp.where(_theta1 < _theta2, _theta1, _theta2)
    theta2 = jnp.where(_theta1 < _theta2, _theta2, _theta1)
    delta = theta2 - theta1

    def s0_integrand(t):
        # G_0 . dr = x dy on the unit circle: cos(t) * cos(t) dt
        return jnp.cos(t) ** 2

    def s1_integrand(t):
        # G_1 . dr on the limb (z = 0); arctan(x / z) -> sign(x) * pi / 2, so
        # the branch depends on the sign of x = cos(t)
        return jnp.where(
            (t < jnp.pi / 2) | (t > 3 * jnp.pi / 2),
            (jnp.pi * jnp.cos(t) * (5 + 3 * jnp.cos(2 * t))) / 24.0,
            -(jnp.pi * jnp.cos(t) * (1 + 3 * jnp.cos(2 * t))) / 24.0,
        )

    def no_wrap(delta):
        # the shorter arc runs directly from theta1 to theta2
        s0, _ = quadgk(
            s0_integrand,
            jnp.array([theta1, theta2]),
            epsabs=epsabs,
            epsrel=epsrel,
        )
        s1, _ = quadgk(
            s1_integrand,
            jnp.array([theta1, theta2]),
            epsabs=epsabs,
            epsrel=epsrel,
        )
        return s0, s1

    def wrap(delta):
        # the shorter arc crosses theta = 0: integrate theta2 -> 2*pi, then
        # 0 -> theta1
        s0, _ = quadgk(
            s0_integrand,
            jnp.array([theta2, 2 * jnp.pi]),
            epsabs=epsabs,
            epsrel=epsrel,
        )
        s0 += quadgk(
            s0_integrand, jnp.array([0, theta1]), epsabs=epsabs, epsrel=epsrel
        )[0]
        s1, _ = quadgk(
            s1_integrand,
            jnp.array([theta2, 2 * jnp.pi]),
            epsabs=epsabs,
            epsrel=epsrel,
        )
        s1 += quadgk(
            s1_integrand, jnp.array([0, theta1]), epsabs=epsabs, epsrel=epsrel
        )[0]
        return s0, s1

    s0, s1 = jax.lax.cond(delta < jnp.pi, no_wrap, wrap, delta)

    solution_vec = jnp.zeros(g_coeffs.shape[0])
    solution_vec = solution_vec.at[0].set(s0)
    solution_vec = solution_vec.at[1].set(s1)
    return solution_vec
@partial(jax.jit, static_argnames=("parameterize_with_projected_ellipse",))
def lightcurve(state, parameterize_with_projected_ellipse):
    """The main function for computing a transit light curve.

    This function will return a 1-D array representing the flux received from
    the star, where each entry corresponds to a time in the input ``state``
    dictionary. It first transforms the ``state`` into the implicit 3D surface
    of the planet, the implicit 2D sky-projected outline of the planet, and a
    parametric form of that outline for each time step. These are vectorized
    operations that are computed simultaneously across all times. It then
    solves for the intersection points of the planet and star. If the planet
    is either partially or fully transiting, it numerically solves the
    required 1D integrals that leverage Green's Theorem to compute the blocked
    flux. The flux-blocking calculations are done sequentially for each
    timestep using ``jax.lax.scan``, which seemed to be more efficient than
    vectorizing again while switching between branches with something like
    ``jax.lax.cond``.

    Keep these different behaviors in mind when computing dense lightcurves
    with ~100s of thousands of time steps: the first part will require enough
    memory to compute and store ~30 values for each step, but then the actual
    1D integrals will be computed sequentially.

    This function writes ``t_peri`` (when ``t0`` is given) and ``f`` into
    ``state``. In the 3D branch it also writes ``prec``.

    Args:
        state (dict): A dictionary containing all of the keys that are
            included in an :func:`OblateSystem` ``state`` attribute.
        parameterize_with_projected_ellipse (bool): If ``True``, the planet's
            outline will be parameterized by the projected ellipse as seen by
            the observer. If ``False``, the planet's outline will be set by
            the full 3D parameterization of the planet. When dealing with
            planets that are not tidally locked and/or far from their host
            star and/or very close to spherical, you won't be able to tell the
            difference between these two parameterizations, since the
            projected area won't be changing. In that case, it's better to use
            the simpler 2D parameterization to avoid the degeneracies and
            extra computation that can arise from the 3D parameterization.
            This argument is static for the JIT-compiled function.

    Returns:
        Array: The flux received from the star at each time step for the times
        included as ``state["times"]``.
    """
    # array we'll modify if the planet is in transit
    fluxes = jnp.ones_like(state["times"])

    if state["t0"] is not None:
        state["t_peri"] = t0_to_t_peri(**state)
    time_deltas = state["times"] - state["t_peri"]
    mean_anomalies = 2 * jnp.pi * time_deltas / state["period"]
    true_anomalies = kepler(mean_anomalies, state["e"])
    state["f"] = true_anomalies

    # convert the u coefficients to g coefficients
    u_coeffs = jnp.ones(state["ld_u_coeffs"].shape[0] + 1) * (-1)
    u_coeffs = u_coeffs.at[1:].set(state["ld_u_coeffs"])
    g_coeffs = jnp.matmul(state["greens_basis_transform"], u_coeffs)

    # total flux from the star. 1/eq. 28 in Agol, Luger, and Foreman-Mackey 2020
    # note, multiply, don't divide
    normalization_constant = 1 / (jnp.pi * (g_coeffs[0] + (2 / 3) * g_coeffs[1]))

    # cartesian position of the planet at each timestep
    positions = skypos(**state)

    if parameterize_with_projected_ellipse:
        area = jnp.pi * state["projected_effective_r"] ** 2
        r1 = jnp.sqrt(area / ((1 - state["projected_f"]) * jnp.pi))
        r2 = r1 * (1 - state["projected_f"])
        two, para = parameterize_2d_helper(
            r1,
            state["projected_f"],
            state["projected_theta"],
            positions[0, :],
            positions[1, :],
        )
        # Force the shapes to match: if the user inputs a scalar for one value
        # but the others are still (1,), there'd be a problem. All is fine if
        # they're all scalars or all arrays, though.
        r1 = jnp.ones_like(r2) * r1
        largest_r = jnp.max(jnp.array([r1, r2]))
    else:
        state["prec"] = jnp.where(state["tidally_locked"], state["f"], state["prec"])
        # the coefficients of the implicit 3d surface
        three = planet_3d_coeffs(**state)
        # the coefficients of the implicit 2d surface
        two = planet_2d_coeffs(**three)
        # the coefficients of the parametric projected ellipse
        para = poly_to_parametric(**two)
        largest_r = state["r"]

    # cheap pre-filter: near the stellar disk (with a 10% margin on the
    # planet's size) and on the observer's side (positions[2] > 0)
    possibly_in_transit = (
        positions[0, :] ** 2 + positions[1, :] ** 2 <= (1.0 + largest_r * 1.1) ** 2
    ) * (positions[2, :] > 0)

    def not_on_limb(X):
        # No limb crossings: the planet is either entirely inside the stellar
        # disk or entirely outside it, decided by where its center lies.
        para, _, _ = X

        def fully_transiting(para):
            solution_vectors = planet_solution_vec(
                a=0.0, b=2 * jnp.pi, g_coeffs=g_coeffs, **para
            )
            blocked_flux = (
                jnp.matmul(g_coeffs, solution_vectors) * normalization_constant
            )
            return blocked_flux

        def not_transiting(para):
            return 0.0

        return jax.lax.cond(
            para["c_x3"] ** 2 + para["c_y3"] ** 2 <= 1,
            fully_transiting,
            not_transiting,
            para,
        )

    def partially_transiting(X):
        para, xs, ys = X
        alphas = cartesian_intersection_to_parametric_angle(xs, ys, **para)
        # missing intersections (sentinel 999) are pushed to the end of the sort
        alphas = jnp.where(xs != 999, alphas, 2 * jnp.pi)
        alphas = jnp.where(alphas < 0, alphas + 2 * jnp.pi, alphas)
        alphas = jnp.where(alphas > 2 * jnp.pi, alphas - 2 * jnp.pi, alphas)
        alphas = jnp.sort(alphas)

        # decide which of the two planet arcs lies inside the star by testing
        # the midpoint of the arc between alphas[0] and alphas[1]
        test_ang = alphas[0] + (alphas[1] - alphas[0]) / 2
        test_ang = jnp.where(test_ang > 2 * jnp.pi, test_ang - 2 * jnp.pi, test_ang)
        _x = (
            para["c_x1"] * jnp.cos(test_ang)
            + para["c_x2"] * jnp.sin(test_ang)
            + para["c_x3"]
        )
        _y = (
            para["c_y1"] * jnp.cos(test_ang)
            + para["c_y2"] * jnp.sin(test_ang)
            + para["c_y3"]
        )
        test_val = jnp.sqrt(_x**2 + _y**2)

        def testval_inside_star(_):
            solution_vectors = planet_solution_vec(
                alphas[0], alphas[1], g_coeffs, **para
            )
            planet_contribution = (
                jnp.matmul(solution_vectors, g_coeffs) * normalization_constant
            )
            return planet_contribution

        def testval_outside_star(_):
            leg1_solution_vec = planet_solution_vec(
                alphas[1], 2 * jnp.pi, g_coeffs, **para
            )
            leg1 = jnp.matmul(leg1_solution_vec, g_coeffs)
            leg2_solution_vec = planet_solution_vec(0.0, alphas[0], g_coeffs, **para)
            leg2 = jnp.matmul(leg2_solution_vec, g_coeffs)
            planet_contribution = (leg1 + leg2) * normalization_constant
            return planet_contribution

        planet_contribution = jax.lax.cond(
            test_val > 1, testval_outside_star, testval_inside_star, ()
        )

        star_solution_vectors = star_solution_vec(
            alphas[0], alphas[1], g_coeffs, **para
        )
        star_contribution = (
            jnp.matmul(star_solution_vectors, g_coeffs) * normalization_constant
        )

        total_blocked = planet_contribution + star_contribution
        return total_blocked

    def transiting(X):
        indv_para, indv_two = X
        xs, ys = _single_intersection_points(**indv_two)
        # all four entries equal the 999 sentinel when there is no real crossing
        on_limb = jnp.sum(xs) != 999 * 4
        return jax.lax.cond(
            on_limb,
            partially_transiting,
            not_on_limb,
            (indv_para, xs, ys),
        )

    def not_transiting(X):
        return 0.0

    def scan_func(carry, scan_over):
        indv_para, indv_two, mask = scan_over
        return None, jax.lax.cond(
            mask, transiting, not_transiting, (indv_para, indv_two)
        )

    # jax.lax.scan needs one value per time step for every entry of `two`, but
    # the quadratic terms don't vary when the planet's orientation is fixed and
    # can come back as scalars (always, for the projected-ellipse
    # parameterization). Check each entry's own shape, rather than inferring
    # it from state["prec"], which may be absent or unrelated in that mode.
    for key in ("rho_xx", "rho_xy", "rho_yy"):
        if jnp.shape(two[key]) != true_anomalies.shape:
            two[key] = jnp.ones_like(state["f"]) * two[key]

    transit_fluxes = jax.lax.scan(
        scan_func, None, (para, two, possibly_in_transit), None
    )[1]

    return fluxes - transit_fluxes