# Source code for kepler

# This is a fork of jaxoplanet/src/jaxoplanet/core/kepler.py; many thanks to the original authors.

import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
from jax.interpreters import ad


@jax.jit
def kepler(M, ecc):
    """Solve Kepler's equation and return the true anomaly.

    Adapted from `jaxoplanet <https://github.com/exoplanet-dev/jaxoplanet/>`_,
    many thanks to the authors.

    Args:
        M (Array [Radian]): Mean anomaly
        ecc (Array [Dimensionless]): Eccentricity

    Returns:
        Array: True anomaly in radians, nominally in the range [0, 2*pi)
    """
    sin_f, cos_f = _kepler(M, ecc)

    # Departure from jaxoplanet: arctan2 yields angles in (-pi, pi]; lift the
    # negative ones by one full turn so the result lands in [0, 2*pi).
    angle = jnp.arctan2(sin_f, cos_f)
    return jnp.where(angle >= 0, angle, angle + 2 * jnp.pi)
@jax.custom_jvp
def _kepler(M, ecc):
    """Solve Kepler's equation and return ``(sin(f), cos(f))`` of the true anomaly.

    The eccentric anomaly is found non-iteratively: a closed-form starter
    (``_starter``) followed by one high-order correction (``_refine``).

    Args:
        M: Mean anomaly in radians. Any real value; it is wrapped into [0, 2*pi).
        ecc: Eccentricity. Must be < 1, since the conversion to the true anomaly
            divides by ``1 - ecc``.

    Returns:
        tuple: ``(sinf, cosf)``, the sine and cosine of the true anomaly.
    """
    # Wrap into the right range
    M = M % (2 * jnp.pi)

    # If E solves E - e*sin(E) = M, then 2*pi - E solves it for 2*pi - M, so it
    # suffices to solve on [0, pi] and reflect the result afterwards.
    high = jnp.pi < M
    M = jnp.where(high, 2 * jnp.pi - M, M)

    # Solve for the eccentric anomaly
    ome = 1 - ecc
    E = _starter(M, ecc, ome)
    E = _refine(M, ecc, ome, E)

    # Re-wrap back into the full range
    E = jnp.where(high, 2 * jnp.pi - E, E)

    # Convert to true anomaly: tan(f/2) = sqrt((1 + e) / (1 - e)) * tan(E/2)
    tan_half_f = jnp.sqrt((1 + ecc) / (1 - ecc)) * jnp.tan(0.5 * E)
    tan2_half_f = jnp.square(tan_half_f)

    # Then we compute sin(f) and cos(f) using:
    # sin(f) = 2*tan(0.5*f)/(1 + tan(0.5*f)^2), and
    # cos(f) = (1 - tan(0.5*f)^2)/(1 + tan(0.5*f)^2)
    denom = 1 / (1 + tan2_half_f)
    sinf = 2 * tan_half_f * denom
    cosf = (1 - tan2_half_f) * denom

    return sinf, cosf


@_kepler.defjvp
def _(primals, tangents):
    """Forward-mode derivative rule for ``_kepler``.

    Uses the analytic derivatives of the true anomaly f,
        df/dM = (1 + e*cos(f))**2 / (1 - e**2)**1.5
        df/de = (2 + e*cos(f)) * sin(f) / (1 - e**2)
    followed by d(sin f) = cos(f) df and d(cos f) = -sin(f) df.
    """
    M, e = primals
    M_dot, e_dot = tangents
    sinf, cosf = _kepler(M, e)

    # Pre-compute some things
    ecosf = e * cosf
    ome2 = 1 - e**2

    # With defjvp's default symbolic_zeros=False, JAX instantiates zero tangents
    # as ordinary arrays before calling this rule. Both tangents can therefore
    # be used directly; no special handling of symbolic zeros (or JAX internals)
    # is needed.
    f_dot = M_dot * (1 + ecosf) ** 2 / ome2**1.5
    f_dot += e_dot * (2 + ecosf) * sinf / ome2

    return (sinf, cosf), (cosf * f_dot, -sinf * f_dot)


def _starter(M, ecc, ome):
    """Closed-form initial guess for the eccentric anomaly.

    ``_kepler`` calls this with M already reduced to [0, pi].
    NOTE(review): the constants appear to match the cubic starter of
    Markley (1995); confirm the reference.

    Args:
        M: Mean anomaly in radians.
        ecc: Eccentricity.
        ome: ``1 - ecc``, precomputed by the caller.

    Returns:
        Approximate eccentric anomaly in radians.
    """
    M2 = jnp.square(M)
    alpha = 3 * jnp.pi / (jnp.pi - 6 / jnp.pi)
    alpha += 1.6 / (jnp.pi - 6 / jnp.pi) * (jnp.pi - M) / (1 + ecc)
    d = 3 * ome + alpha * ecc
    alphad = alpha * d
    r = (3 * alphad * (d - ome) + M2) * M
    q = 2 * alphad * ome - M2
    q2 = jnp.square(q)
    w = jnp.square(jnp.cbrt(jnp.abs(r) + jnp.sqrt(q2 * q + r * r)))
    return (2 * r * w / (jnp.square(w) + w * q + q2) + M) / d


def _refine(M, ecc, ome, E):
    """Apply a single high-order correction to the eccentric anomaly estimate E.

    Args:
        M: Mean anomaly in radians.
        ecc: Eccentricity.
        ome: ``1 - ecc``.
        E: Current estimate of the eccentric anomaly.

    Returns:
        Corrected eccentric anomaly ``E + dE``.
    """
    sE = E - jnp.sin(E)
    cE = 1 - jnp.cos(E)

    # With g(E) = E - ecc*sin(E) - M, f_0..f_3 are g and its first three
    # derivatives at E:
    f_0 = ecc * sE + E * ome - M  # g(E)
    f_1 = ecc * cE + ome  # g'(E) = 1 - ecc*cos(E)
    f_2 = ecc * (E - sE)  # g''(E) = ecc*sin(E)
    f_3 = 1 - f_1  # g'''(E) = ecc*cos(E)
    # Successively higher-order estimates of the step, each built on the last
    d_3 = -f_0 / (f_1 - 0.5 * f_0 * f_2 / f_1)
    d_4 = -f_0 / (f_1 + 0.5 * d_3 * f_2 + (d_3 * d_3) * f_3 / 6)
    d_42 = d_4 * d_4
    # g''''(E) = -ecc*sin(E) = -f_2, hence the minus sign on the last term
    dE = -f_0 / (f_1 + 0.5 * d_4 * f_2 + d_4 * d_4 * f_3 / 6 - d_42 * d_4 * f_2 / 24)
    return E + dE


def _x(a, e, f, Omega, i, omega):
    """Sky-frame x coordinate of the planet.

    Algebraically equal to r*(cos(Omega)*cos(omega + f)
    - cos(i)*sin(Omega)*sin(omega + f)), with r = a*(1 - e**2)/(1 + e*cos(f)).
    """
    return (
        a
        * (-1 + e**2)
        * (
            jnp.sin(f)
            * (
                jnp.cos(Omega) * jnp.sin(omega)
                + jnp.cos(i) * jnp.cos(omega) * jnp.sin(Omega)
            )
            + jnp.cos(f)
            * (
                -(jnp.cos(omega) * jnp.cos(Omega))
                + jnp.cos(i) * jnp.sin(omega) * jnp.sin(Omega)
            )
        )
    ) / (1 + e * jnp.cos(f))


def _y(a, e, f, Omega, i, omega):
    """Sky-frame y coordinate of the planet.

    Algebraically equal to r*(sin(Omega)*cos(omega + f)
    + cos(i)*cos(Omega)*sin(omega + f)), with r = a*(1 - e**2)/(1 + e*cos(f)).
    """
    return -(
        (
            a
            * (-1 + e**2)
            * (
                jnp.cos(i) * jnp.cos(Omega) * jnp.sin(f + omega)
                + jnp.cos(f + omega) * jnp.sin(Omega)
            )
        )
        / (1 + e * jnp.cos(f))
    )


def _z(a, e, f, Omega, i, omega):
    """Sky-frame z coordinate of the planet.

    Algebraically equal to r*sin(i)*sin(omega + f), with
    r = a*(1 - e**2)/(1 + e*cos(f)). ``Omega`` is accepted for signature
    symmetry with ``_x``/``_y`` but does not enter the result.
    """
    return -((a * (-1 + e**2) * jnp.sin(i) * jnp.sin(f + omega)) / (1 + e * jnp.cos(f)))
@jax.jit
def skypos(a, e, f, Omega, i, omega, **kwargs):
    """Compute the cartesian coordinates of the center of the planet in the sky frame.

    Args:
        a (Array [Rstar]): Semi-major axis of the orbit in units of R_star
        e (Array [Dimensionless]): Eccentricity of the orbit
        f (Array [Radian]): True anomaly, the angle between the direction of
            periapsis and the current position of the planet as seen from the star.
        Omega (Array [Radian]): Longitude of the ascending node
        i (Array [Radian]): Orbital inclination
        omega (Array [Radian]): Argument of periapsis
        **kwargs: Unused additional keyword arguments. These are included so that
            we can take in a larger state dictionary that includes all of the
            required parameters along with other unnecessary ones. Because this
            function is jitted, those extra values are still traced, so they must
            be JAX-compatible (arrays, scalars, or pytrees of them).

    Returns:
        Array: The cartesian coordinates of the planet in the sky frame, stacked
        along a new leading axis. Shape [3, N].
    """
    x = _x(a, e, f, Omega, i, omega)
    y = _y(a, e, f, Omega, i, omega)
    z = _z(a, e, f, Omega, i, omega)
    # _z does not depend on Omega, so its shape can differ from that of x and y
    # (e.g. an array Omega with scalar f). Broadcast before stacking so such
    # inputs do not produce a ragged list.
    x, y, z = jnp.broadcast_arrays(x, y, z)
    return jnp.stack([x, y, z])
def true_anomaly_at_transit_center(e, i, omega):
    """Return the true anomaly at the instant of minimum star/planet separation.

    Evaluates equations 4.12-4.18 of `Kipping 2011
    <https://ui.adsabs.harvard.edu/abs/2011PhDT.......294K/abstract>`_. The
    result is pi/2 - omega minus six correction terms; the n-th term carries a
    factor (cos(i)**2)**n. All corrections are proportional to e*cos(omega), so
    for e = 0 the result is exactly pi/2 - omega.

    Args:
        e (Array [Dimensionless]): Eccentricity of the orbit
        i (Array [Radian]): Orbital inclination
        omega (Array [Radian]): Argument of periapsis

    Returns:
        Array: True anomaly at the instant of minimum star/planet separation in
        radians.
    """
    h = e * jnp.sin(omega)
    k = e * jnp.cos(omega)
    one_plus_h = 1 + h  # recurring denominator 1 + e*sin(omega)
    ratio = k / one_plus_h  # common prefactor of every correction term
    c2 = jnp.cos(i) ** 2

    eta_1 = ratio * c2
    eta_2 = ratio * (1 / one_plus_h) * c2**2
    eta_3 = (
        -ratio
        * ((-6 * one_plus_h + k**2 * (-1 + 2 * h)) / (6 * one_plus_h**3))
        * c2**3
    )
    eta_4 = (
        -ratio
        * ((-2 * one_plus_h + k**2 * (-1 + 3 * h)) / (2 * one_plus_h**4))
        * c2**4
    )
    eta_5 = (
        ratio
        * (
            (
                40 * one_plus_h**2
                - 40 * k**2 * (-1 + 3 * h + 4 * h**2)
                + k**4 * (3 - 19 * h + 8 * h**2)
            )
            / (40 * one_plus_h**6)
        )
        * c2**5
    )
    eta_6 = (
        ratio
        * (
            (
                24 * one_plus_h**2
                - 40 * k**2 * (-1 + 4 * h + 5 * h**2)
                + 9 * k**4 * (1 - 8 * h + 5 * h**2)
            )
            / (24 * one_plus_h**7)
        )
        * c2**6
    )

    # Subtract the corrections one at a time, in order, starting from the
    # e = 0 value.
    f_transit = jnp.pi / 2 - omega
    for eta in (eta_1, eta_2, eta_3, eta_4, eta_5, eta_6):
        f_transit = f_transit - eta
    return f_transit
def t0_to_t_peri(e, i, omega, period, t0, **kwargs):
    """Convert a time of transit center into a time of periapsis passage.

    Args:
        e (Array [Dimensionless]): Eccentricity of the orbit
        i (Array [Radian]): Orbital inclination
        omega (Array [Radian]): Argument of periapsis
        period (Array): Orbital period
        t0 (Array): Time of transit center, in the same time units as ``period``
        **kwargs: Unused additional keyword arguments, accepted so that a larger
            state dictionary can be passed in directly.

    Returns:
        Array: Time of periapsis passage. The eccentric anomaly at transit is
        taken in (-pi, pi], so the result lies within half a period of ``t0``.
    """
    f_transit = true_anomaly_at_transit_center(e, i, omega)

    # Eccentric anomaly from the true anomaly. For e < 1, both arctan2
    # arguments share the positive factor 1 / (1 + e*cos(f)), which cancels.
    sin_part = jnp.sqrt(1 - e**2) * jnp.sin(f_transit)
    cos_part = e + jnp.cos(f_transit)
    ecc_anom = jnp.arctan2(sin_part, cos_part)

    # Kepler's equation: mean anomaly elapsed since periapsis at transit.
    mean_anom = ecc_anom - e * jnp.sin(ecc_anom)

    return t0 - period / (2 * jnp.pi) * mean_anom