Source code for oblate_system

import jax

jax.config.update("jax_enable_x64", True)
import copy
import pprint
from functools import partial

import jax.numpy as jnp
import matplotlib.pyplot as plt

from squishyplanet.engine.greens_basis_transform import generate_change_of_basis_matrix
from squishyplanet.engine.kepler import kepler, skypos, t0_to_t_peri
from squishyplanet.engine.parametric_ellipse import (
    poly_to_parametric,
    poly_to_parametric_helper,
)
from squishyplanet.engine.phase_curve_utils import (
    corrected_emission_profile,
    emission_phase_curve,
    extended_illumination_reflected_phase_curve,
    generate_sample_radii_thetas,
    lambertian_reflection,
    phase_curve,
    planet_surface_normal,
    pre_squish_transform,
    reflected_phase_curve,
    sample_surface,
    stellar_doppler_variations,
    stellar_ellipsoidal_variations,
    surface_star_cos_angle,
)
from squishyplanet.engine.planet_2d import planet_2d_coeffs
from squishyplanet.engine.planet_3d import (
    extended_illumination_offsets,
    planet_3d_coeffs,
    planet_3d_coeffs_extended_illumination,
)
from squishyplanet.engine.polynomial_limb_darkened_transit import (
    lightcurve as poly_lightcurve,
)
from squishyplanet.engine.polynomial_limb_darkened_transit import parameterize_2d_helper


[docs] class OblateSystem: """The core user interface for ``squishyplanet``, used to model potentially-triaxial exoplanet transits/phase curves. Note, all instances will have values associated with phase curve calculations, such as albedo and hotspot location. However, if inputs such are "compute_reflected_phase_curve" are set to ``False``, these values will not be used. The :func:`lightcurve` method will return only the transit light curve in this case, and should be used if computing a transit, reflected, or emitted phase curve. All arguments will be internally converted to ``jax.numpy`` dtypes, and all methods will similarly return ``jax.numpy`` arrays. These can be treated similarly to numpy arrays in most cases, but if passing outputs to external inference libraries expecting numpy, you may need to explicitly convert them. Properties: state (dict): A dictionary of all the parameters of the system, including those specified by the user, default values, and those calculated by combinations of the two. Immutable, but can be accessed to see the current state of the system. Args: times (array-like, [Days], default=None): The times at which to calculate the light curve. The gap between times is assumed to be in units of days, but any zero-point/standard system (e.g. BJD) will work. A required parameter, will raise an error if not provided. t_peri (float, [Days], default=None): The time of periastron passage. One of ``t_peri`` or ``t0`` must be provided. t0 (float, [Days], default=None): The time of transit center. One of ``t_peri`` or ``t0`` must be provided. period (float, [Days], default=None): The period of the orbit. A required parameter, will raise an error if not provided. a (float, [Rstar], default=None): The semi-major axis of the orbit in units of the radius of the star. A required parameter, will raise an error if not provided. tidally_locked (bool, default=None): Whether the planet is tidally locked to the star. If ``True``, then ``prec`` will always be set equal to the true anomaly, meaning the same face of the planet will always face the star. A required parameter, will raise an error if not provided. e (float, default=0.0): The eccentricity of the orbit. i (float, [Radian], default=jnp.pi / 2): The inclination of the orbit. Omega (float, [Radian], default=jnp.pi): The longitude of the ascending node. Changing this will **not affect** the transit light curve (more accurately, changes can always be compensated for via rotations in obliq or prec). It is included only because it naturally arises in the orbit rotations, and I guess could come into play if anyone ever wants to do a joint astrometry model. omega (float, [Radian], default=0.0): The argument of periapsis. Set to 0.0 for a circular orbit, otherwise there will be degeneracies with t_peri. obliq (float, [Radian], default=0.0): The obliquity of the planet. This is the angle between the planet's rotation axis and the normal to the orbital plane. It is defined such that a planet on a circular orbit with :math:`\\Omega = 0` and :math:`\\nu = 0` (i.e., when it's along the positive :math:`x` axis) will have its north pole tipped *away* from the star. prec (float, [Radian], default=0.0): The "precession angle" of the planet. This defined as a rotation of the planet about an axis that's aligned with its orbit normal and runs through the center of the planet (e.g., if obliq=0, it would set the planet's instantaneous rotational phase, and if obliq :math:`\\neq` 0, it would set the "season" of the northern hemisphere at periastron passage.) r (float, [Rstar], default=None): The equatorial radius of the planet. This will always be the largest of the 3 axes of the triaxial ellipsoid. Either this or the entire set of ``projected_effective_r``, ``projected_f``, and ``projected_theta`` must be provided. f1 (float, [Dimensionless], default=0.0): The fractional difference between the (longest) equatorial and polar radii of the planet. This is defined as :math:`(R_{eq} - R_{pol}) / R_{eq}`. f2 (float, [Dimensionless], default=0.0): The fractional difference between long and short radii of the ellipse that defines the equator of the planet. Defined similarly to f1. ld_u_coeffs (array-like, default=jnp.array([0.0, 0.0])): The coefficients that determine the limb darkening profile of the star. The star is assumed to be azimuthally symmetric and have a radial profile described by: .. math:: \\frac{I(\\mu)}{I_0} = - \\Sigma_{i=0}^N u_i (1 - \\mu)^i for some order polynomial :math:`N`. See `Agol, Luger, and Foreman-Mackey 2020 <https://ui.adsabs.harvard.edu/abs/2020AJ....159..123A/abstract>`_ for more. hotspot_latitude (float, [Radian], default=0.0): The latitude of a potential hotspot on the planet. This is defined according to the "physics" convention of spherical coordinates, not in the geography sense: 0 is the north pole, :math:`\\pi/2` is the equator, and :math:`\\pi` is the south pole. hotspot_longitude (float, [Radian], default=0.0): The longitude of a potential hotspot on the planet. hotspot_concentration (float, default=0.2): The "concentration" of the hotspot. This is the :math:`\\kappa` parameter in the von Mises-Fisher distribution that describes the hotspot. albedo (float, default=1.0): The (spatialy uniform) albedo of the planet. This is the fraction of light that is reflected, though the directional-dependent scattering is dictated by Lambert's cosine law. emitted_scale (float, default=1e-5): A scale factor that sets the amplitude of the the emitted flux of the planet. systematic_trend_coeffs (array-like, default=jnp.array([0.0,0.0])): The coefficients that determine the polynomial trend in time added to the lightcurves. Used to optionally model long-term drifts in observed data. log_jitter (float, default=-jnp.inf): The log of the "jitter" term included in likelihood calculations. The jitter is added in quadrature to the provided uncertainties to account for any unmodeled noise in the data. This value is the *standard deviation* of the jitter, not the variance. If set to -jnp.inf, the jitter term will not affect the likelihood. projected_effective_r (float, [Rstar], default=0.0): The radius of a circle with the same area is the projected ellipse. This is only relevant if ``parameterize_with_projected_ellipse`` is set to ``True``, which will override ``r``, ``f1``, ``f2``, ``obliq``, and ``prec``. projected_f (float, [Dimensionless], default=0.0): The flattening value of the projected ellipse. This is only relevant if ``parameterize_with_projected_ellipse`` is set to ``True``, which will override ``r``, ``f1``, ``f2``, ``obliq``, and ``prec``. projected_theta (float, [Radian], default=0.0): The angle of the semi-major axis of the projected ellipse. This is only relevant if ``parameterize_with_projected_ellipse`` is set to ``True``, which will override ``r``, ``f1``, ``f2``, ``obliq``, and ``prec``. extended_illumination_npts (int, default=1): The number of points used to sample the star's projected disk as seen by the planet when calculating the reflected flux and accounting for the star's extended size. Closely follows the implementation in ``starry``, specifically Sec. 4.1 of `Luger et al. 2022 <https://ui.adsabs.harvard.edu/abs/2022AJ....164....4L/abstract>`_. IMPLEMENTATION IS INCOMPLETE, WILL RAISE A NOTIMPLEMENTEDERROR IF SET TO ANYTHING OTHER THAN 1. compute_reflected_phase_curve (bool, default=False): Whether to include flux reflected by the planet when calling :func:`lightcurve`. compute_emitted_phase_curve (bool, default=False): Whether to include flux emitted by the planet when calling :func:`lightcurve`. compute_stellar_ellipsoidal_variations (bool, default=False): Whether to include stellar ellipsoidal variations in the light curve. This is the effect of the star's shape changing due to the gravitational pull of the planet, and here is modeled as a simple sinusoidal variation with 4 peaks per orbit. compute_stellar_doppler_variations (bool, default=False): Whether to include stellar doppler variations in the light curve. This captures the effects of the star's radial velocity changing and boosting the total flux/pushing some flux into/out of the bandpass of the observation. Here, it is modeled as a simple sinusoidal variation with 2 peaks per orbit. parameterize_with_projected_ellipse (bool, default=False): Whether to parameterize the planet as a projected ellipse rather than a triaxial ellipsoid. If ``True``, then ``projected_effective_r``, ``projected_f``, and ``projected_theta`` will be used. phase_curve_nsamples (int, default=50_000): The number of random samples of the planet's surface to draw when performing Monte Carlo estimates of the emitted/reflected flux. A larger number will increase the resolution/shrink the error of the estimate but result in longer computation times. random_seed (int, default=0): A random seed used for the Monte Carlo integrals in the phase curve. This feeds into ``jax.random.PRNGKey``. Runs with the same ``random_seed`` will always return identical outputs, so if checking the affect of altering ``phase_curve_nsamples``, you should change this as well. data (array-like, default=jnp.array([1.0])): The observed data to compare to the light curve. Must be the same length as ``times``. Only needed if calling :func:`loglike`. uncertainties (array-like, default=jnp.array([0.01])): The uncertainties on the observed data. Must be the same length as ``data``, even if the errors are homoskedastic. Only needed if calling :func:`loglike`. exposure_time (float, [Days], default=0.0): The length of each exposure in the light curve, used to correct for finite integration times if ``oversample`` is set to a value greater than 1. Important: the finite exposure time correction procedure assumes that the given times correspond to the **midpoints** of each exposure, not the *start* or *end*. No checks are made to ensure that it is shorter than the minimum time difference between the provided times. oversample (int, default=1): The factor by which to oversample the light curve to partially compensate for finite-time integrations. The overdense lightcurve is then binned down to the original provided times. See e.g. `Kipping 2010 <https://ui.adsabs.harvard.edu/abs/2010MNRAS.408.1758K/abstract>`_, "Binning is Sinning" for more. Must be a positive integer. Will be rounded up to nearest odd number. oversample_correction_order (int, default=2): After oversampling the light curve, how do you want to integrate over the exposure time to get the final binned light curve? This follows ``starry``'s treatment very closely: 0 is a centered Riemann sum like in Kipping 2010, 1 is a trapezoidal rule, and 2 is Simpson's rule. Must be one of those values. """ def __init__( self, times=None, t_peri=None, t0=None, period=None, a=None, tidally_locked=None, e=0.0, i=jnp.pi / 2, Omega=jnp.pi, omega=0.0, obliq=0.0, prec=0.0, r=None, f1=0.0, f2=0.0, ld_u_coeffs=jnp.array([0.0, 0.0]), hotspot_latitude=0.0, hotspot_longitude=0.0, hotspot_concentration=0.2, albedo=1.0, emitted_scale=1e-5, stellar_ellipsoidal_alpha=1e-6, stellar_doppler_alpha=1e-6, systematic_trend_coeffs=jnp.array([0.0, 0.0]), log_jitter=-jnp.inf, projected_effective_r=0.0, projected_f=0.0, projected_theta=0.0, extended_illumination_npts=1, compute_reflected_phase_curve=False, compute_emitted_phase_curve=False, compute_stellar_ellipsoidal_variations=False, compute_stellar_doppler_variations=False, parameterize_with_projected_ellipse=False, phase_curve_nsamples=50_000, random_seed=0, data=jnp.array([1.0]), uncertainties=jnp.array([jnp.inf]), exposure_time=0.0, oversample=1, oversample_correction_order=2, ): ####################################################################### # setup ####################################################################### state_keys = list(locals().keys()) state_keys.remove("self") state = {} for key in state_keys: state[key] = locals()[key] self._state = state self._validate_inputs() ####################################################################### # 1-time calculations ####################################################################### # necessary for all light curves self._state["greens_basis_transform"] = generate_change_of_basis_matrix( len(self._state["ld_u_coeffs"]) ) # for oversampling if self._state["oversample"] > 1: self._state["oversample"] += 1 - self._state["oversample"] % 2 self._state["stencil"] = jnp.ones(self._state["oversample"]) # Construct the exposure time integration stencil if self._state["oversample_correction_order"] == 0: dt = jnp.linspace(-0.5, 0.5, 2 * self._state["oversample"] + 1)[1:-1:2] elif self._state["oversample_correction_order"] == 1: dt = jnp.linspace(-0.5, 0.5, self._state["oversample"]) self._state["stencil"] = self._state["stencil"].at[1:-1].set(2) elif self._state["oversample_correction_order"] == 2: dt = jnp.linspace(-0.5, 0.5, self._state["oversample"]) self._state["stencil"] = self._state["stencil"].at[1:-1:2].set(4) self._state["stencil"] = self._state["stencil"].at[2:-1:2].set(2) self._state["stencil"] = self._state["stencil"] / jnp.sum( self._state["stencil"] ) dt = self._state["exposure_time"] * dt t = self._state["times"][:, None] + dt[None, :] t = t.reshape(-1) self._state["times"] = t else: self._state["times"] = self._state["times"] self._state["stencil"] = ( None # never used in this case, but to keep the state consistent ) # # for extended illumination reflection curves # actually, still having trouble with this, so setting aside for now-- # leaving it for a specific enhancement after 0.1.0 if self._state["compute_reflected_phase_curve"] > 1: raise NotImplementedError( "Extended illumination reflection curves are not yet implemented" ) # # based on starry._core.core.py's OpsReflected(OpsYlm) class # # create a grid of points uniformly distributed on the projected disk of the # # sta from an observer along the z-axis # N = int(2 + jnp.sqrt(self._state["extended_illumination_npts"] * 4 / jnp.pi)) # # note these points will be squished closer together during calculations to # # account for the a-dependent extent of the star from the planet's perspective # dx = jnp.linspace(-1 + 1e-12, 1 - 1e-12, N) # dx, dy = jnp.meshgrid(dx, dx) # # dz = jnp.sqrt(1 - dx**2 - dy**2) # dz = 1 - dx**2 - dy**2 # source_dx = dx[dz > 0].flatten() # source_dy = dy[dz > 0].flatten() # source_dz = dz[dz > 0].flatten() # pts = jnp.array([source_dx, source_dy, jnp.sqrt(source_dz)]).T # self._state["extended_illumination_points"] = pts # self._state["extended_illumination_npts"] = len(pts) # everything below here is just an instantaneous snapshot mostly for plotting, # these will all vary with different parameter inputs if self._state["t_peri"] is None: tp = t0_to_t_peri(**self._state) else: tp = self._state["t_peri"] time_deltas = self._state["times"] - tp mean_anomalies = 2 * jnp.pi * time_deltas / state["period"] true_anomalies = kepler(mean_anomalies, state["e"]) self._state["f"] = true_anomalies if self._state["tidally_locked"]: self._state["prec"] = self._state["f"] positions = skypos(**state) self._state["x_c"] = positions[0, :] self._state["y_c"] = positions[1, :] self._state["z_c"] = positions[2, :] if not self._state["parameterize_with_projected_ellipse"]: self._coeffs_3d = planet_3d_coeffs(**self._state) for key in self._coeffs_3d: if self._coeffs_3d[key].shape[0] == (): self._coeffs_3d[key] = jnp.array([self._coeffs_3d[key]]) self._coeffs_2d = planet_2d_coeffs(**self._coeffs_3d) self._para_coeffs_2d = poly_to_parametric(**self._coeffs_2d) r1, r2, _, _, cosa, sina = poly_to_parametric_helper(**self._coeffs_2d) area = jnp.pi * r1 * r2 effective_r = jnp.sqrt(area / jnp.pi) self._state["projected_effective_r"] = effective_r effective_theta = jnp.arctan(sina / cosa) effective_theta = jnp.where( effective_theta < 0, effective_theta + jnp.pi, effective_theta ) self._state["projected_theta"] = effective_theta effective_f = ( jnp.max(jnp.array([r1, r2])) - jnp.min(jnp.array([r1, r2])) ) / jnp.max(jnp.array([r1, r2])) self._state["projected_f"] = effective_f else: self._coeffs_3d = {} area = jnp.pi * self._state["projected_effective_r"] ** 2 r1 = jnp.sqrt(area / ((1 - self._state["projected_f"]) * jnp.pi)) r2 = r1 * (1 - self._state["projected_f"]) self._coeffs_2d, self._para_coeffs_2d = parameterize_2d_helper( projected_r=r1, projected_f=self._state["projected_f"], projected_theta=self._state["projected_theta"], xc=self._state["x_c"], yc=self._state["y_c"], ) self._lightcurve_fwd_grad_enforced = self._setup_lightcurve_func() self._loglike_fwd_grad_enforced = self._setup_loglike_func() def __repr__(self): s = pprint.pformat(self.state) return f"OblateSystem(\n{s}\n)" @property def state(self): """A dictionary that includes all of the parameters of the system. This is an immutable property, and will raise an error if you try to set it. If altering parameters that would affect a lightcurve, pass those as a dictionary to the :func:`lightcurve` method. If altering the data or times at which to generate the lightcurve, just define a new system with those values. Returns: dict: A dictionary of all the parameters of the system, including those specified by the user, default values, and those calculated by combinations of the two. """ # we internally changed "times" if oversample > 1, but we alwasy bin it back # down to the original times, so we can undo that expansion here s = copy.deepcopy(self._state) if s["oversample"] > 1: s["times"] = ( s["times"].reshape(-1, s["oversample"]) * s["stencil"][None, :] ).sum(axis=1) return s def _validate_inputs(self): for key, val in self._state.items(): if val is None: if (key == "r") & (self._state["parameterize_with_projected_ellipse"]): self._state["r"] = 0.0 continue if key == "t_peri": continue if key == "t0": continue raise ValueError(f"'{key}' is a required parameter") assert (self._state["t_peri"] is None) != ( self._state["t0"] is None ), "Exactly one of 't_peri' or 't0' must be specified" self._state["ld_u_coeffs"] = jnp.array(self._state["ld_u_coeffs"]) assert ( self._state["ld_u_coeffs"].shape[0] >= 2 ), "ld_u_coeffs must have at least 2 (even if higher-order terms are 0)" assert isinstance( self._state["phase_curve_nsamples"], int ), "phase_curve_nsamples must be an integer" assert isinstance( self._state["random_seed"], int ), "random_seed must be an integer" if self._state["e"] == 0: assert self._state["omega"] == 0, "omega must be 0 for a circular orbit" shapes = [] for key in self._state: if ( (key == "times") | (key == "ld_u_coeffs") | (key == "phase_curve_nsamples") | (key == "random_seed") | (key == "data") | (key == "uncertainties") | (key == "systematic_trend_coeffs") | (key == "exposure_time") | (key == "oversample") | (key == "oversample_correction_order") | (key == "extended_illumination_npts") ) or isinstance(self._state[key], bool): continue if isinstance(self._state[key], float | int): self._state[key] = jnp.array([self._state[key]]) shapes.append(1) else: if self._state[key] is None: continue # still one None hanging around, either t0 or t_peri if len(self._state[key].shape) > 1: raise ValueError( "All parameters must be scalars or 1D arrays of the same shape." ) if self._state[key].shape == (): self._state[key] = jnp.array([self._state[key]]) shapes.append(1) else: shapes.append(self._state[key].shape[0]) if len(jnp.unique(jnp.array(shapes))) > 2: raise ValueError( "All parameters must be scalars or arrays of the same shape." ) if self._state["parameterize_with_projected_ellipse"]: assert self._state["projected_effective_r"] > 0, ( "projected_effective_r must be greater than 0 if " "parameterize_with_projected_ellipse is True" ) assert not ( self._state["compute_reflected_phase_curve"] | self._state["compute_emitted_phase_curve"] | self._state["compute_stellar_ellipsoidal_variations"] | self._state["compute_stellar_doppler_variations"] ), ( "parameterize_with_projected_ellipse is incompatible with phase " "curve calculations" ) assert not self._state["tidally_locked"], ( "parameterize_with_projected_ellipse is incompatible with " "tidally_locked=True" ) assert (self._state["oversample_correction_order"] in [0, 1, 2]) & ( isinstance(self._state["oversample_correction_order"], int) ), "oversample_correction_order must be 0, 1, or 2" assert self._state["oversample"] > 0, "oversample must be greater than 0" if self._state["oversample"] > 1: assert ( self._state["exposure_time"] is not None ), "exposure_time must be provided if oversample > 1" if self._state["compute_stellar_ellipsoidal_variations"]: assert ( self._state["e"] == 0.0 ), "Stellar ellipsoidal variations are only valid for circular orbits" if self._state["compute_stellar_doppler_variations"]: assert ( self._state["e"] == 0.0 ), "Stellar doppler variations are only valid for circular orbits" def _setup_lightcurve_func(self): constants = { "tidally_locked": self._state["tidally_locked"], "compute_reflected_phase_curve": self._state[ "compute_reflected_phase_curve" ], "compute_emitted_phase_curve": self._state["compute_emitted_phase_curve"], "compute_stellar_ellipsoidal_variations": self._state[ "compute_stellar_ellipsoidal_variations" ], "compute_stellar_doppler_variations": self._state[ "compute_stellar_doppler_variations" ], "parameterize_with_projected_ellipse": self._state[ "parameterize_with_projected_ellipse" ], "oversample": self._state["oversample"], "random_seed": self._state["random_seed"], "phase_curve_nsamples": self._state["phase_curve_nsamples"], "extended_illumination_npts": self._state["extended_illumination_npts"], "state": self._state, } frozen = jax.tree_util.Partial(_lightcurve, **constants) @jax.custom_vjp def lightcurve(params): return frozen(params) def lightcurve_fwd(params): output = frozen(params) jac = jax.jacfwd(frozen)(params) return output, (jac,) def lightcurve_bwd(res, g): jac = res val = jax.tree.map(lambda x: x.T @ g, jac) return val lightcurve.defvjp(lightcurve_fwd, lightcurve_bwd) lightcurve = jax.jit(lightcurve) return lightcurve def _setup_loglike_func(self): constants = { "tidally_locked": self._state["tidally_locked"], "compute_reflected_phase_curve": self._state[ "compute_reflected_phase_curve" ], "compute_emitted_phase_curve": self._state["compute_emitted_phase_curve"], "compute_stellar_ellipsoidal_variations": self._state[ "compute_stellar_ellipsoidal_variations" ], "compute_stellar_doppler_variations": self._state[ "compute_stellar_doppler_variations" ], "parameterize_with_projected_ellipse": self._state[ "parameterize_with_projected_ellipse" ], "oversample": self._state["oversample"], "random_seed": self._state["random_seed"], "phase_curve_nsamples": self._state["phase_curve_nsamples"], "extended_illumination_npts": self._state["extended_illumination_npts"], "state": self._state, } frozen = jax.tree_util.Partial(_loglike, **constants) @jax.custom_vjp def loglike(params): return frozen(params) def loglike_fwd(params): output = frozen(params) jac = jax.jacfwd(frozen)(params) return output, jac def loglike_bwd(res, g): val = jax.tree.map(lambda x: x.T * g, res) return (val,) loglike.defvjp(loglike_fwd, loglike_bwd) loglike = jax.jit(loglike) return loglike def _illustrate_helper(self, times=None, true_anomalies=None, nsamples=50_000): if (times is not None) & (true_anomalies is not None): raise ValueError("Provide either times or true anomalies but not both") if times is not None: t_peri = self._state.get("t_peri", None) if t_peri is None: t_peri = t0_to_t_peri( e=self._state["e"], i=self._state["i"], omega=self._state["omega"], period=self._state["period"], t0=self._state["t0"], ) time_deltas = times - t_peri mean_anomalies = 2 * jnp.pi * time_deltas / self._state["period"] true_anomalies = kepler(mean_anomalies, self._state["e"]) elif true_anomalies is not None: pass else: true_anomalies = jnp.array([jnp.pi / 2]) if isinstance(true_anomalies, float | int): true_anomalies = jnp.array([true_anomalies]) # the trace of the orbit fs = jnp.linspace(0, 2 * jnp.pi, 300) orbit_positions = skypos( a=self._state["a"], e=self._state["e"], f=fs, Omega=self._state["Omega"], i=self._state["i"], omega=self._state["omega"], ) behind_star = ( (orbit_positions[0, :] ** 2 + orbit_positions[1, :] ** 2) < 1 ) & (orbit_positions[2, :] < 0) orbit_positions = orbit_positions.at[:, behind_star].set(jnp.nan) original_state = copy.deepcopy(self._state) original_3d_coeffs = copy.deepcopy(self._coeffs_3d) original_2d_coeffs = copy.deepcopy(self._coeffs_2d) original_para_coeffs_2d = copy.deepcopy(self._para_coeffs_2d) X_outline = [] Y_outline = [] Xs = [] Ys = [] Reflection = [] Emission = [] for i in range(len(true_anomalies)): # all of these could just be done in one go, # but bookkeeping was easier this way self._state["f"] = jnp.array([true_anomalies[i]]) if self._state["tidally_locked"]: self._state["prec"] = self._state["f"] # self._coeffs_3d = planet_3d_coeffs(**self._state) # self._coeffs_2d = planet_2d_coeffs(**self._coeffs_3d) # self._para_coeffs_2d = poly_to_parametric(**self._coeffs_2d) positions = skypos(**self._state) self._state["x_c"] = positions[0, :] self._state["y_c"] = positions[1, :] self._state["z_c"] = positions[2, :] if not self._state["parameterize_with_projected_ellipse"]: self._coeffs_3d = planet_3d_coeffs(**self._state) for key in self._coeffs_3d: if self._coeffs_3d[key].shape[0] == (): self._coeffs_3d[key] = jnp.array([self._coeffs_3d[key]]) self._coeffs_2d = planet_2d_coeffs(**self._coeffs_3d) self._para_coeffs_2d = poly_to_parametric(**self._coeffs_2d) else: self._coeffs_3d = {} area = jnp.pi * self._state["projected_effective_r"] ** 2 r1 = jnp.sqrt(area / ((1 - self._state["projected_f"]) * jnp.pi)) self._coeffs_2d, self._para_coeffs_2d = parameterize_2d_helper( projected_r=r1, projected_f=self._state["projected_f"], projected_theta=self._state["projected_theta"], xc=self._state["x_c"], yc=self._state["y_c"], ) # the boundary of the planet thetas = jnp.linspace(0, 2 * jnp.pi, 200) x_outline = ( self._para_coeffs_2d["c_x1"] * jnp.cos(thetas) + self._para_coeffs_2d["c_x2"] * jnp.sin(thetas) + self._para_coeffs_2d["c_x3"] ) y_outline = ( self._para_coeffs_2d["c_y1"] * jnp.cos(thetas) + self._para_coeffs_2d["c_y2"] * jnp.sin(thetas) + self._para_coeffs_2d["c_y3"] ) if not self._state["parameterize_with_projected_ellipse"]: # the phase curve bits sample_radii, sample_thetas = generate_sample_radii_thetas( jax.random.key(0), jnp.arange(nsamples) ) x, y, z = sample_surface( sample_radii, sample_thetas, **self._coeffs_2d, **self._coeffs_3d, ) # the reflected brightness profile normals = planet_surface_normal(x, y, z, **self._coeffs_3d) star_cos_ang = surface_star_cos_angle( normals, self._state["x_c"], self._state["y_c"], self._state["z_c"], ) reflection = lambertian_reflection(star_cos_ang, x, y, z) # the emitted brightness profile # need to take the first index since you aren't scanning here transform = pre_squish_transform(**self._state)[0] emission = corrected_emission_profile( x, y, z, transform, **self._state, ) behind_star = ((x**2 + y**2) < 1) & (z < 0) reflection = jnp.where(behind_star, jnp.nan, reflection) emission = jnp.where(behind_star, jnp.nan, emission) else: x = jnp.nan y = jnp.nan reflection = jnp.nan emission = jnp.nan X_outline.append(x_outline) Y_outline.append(y_outline) Xs.append(x) Ys.append(y) Reflection.append(reflection) Emission.append(emission) X_outline = jnp.array(X_outline) Y_outline = jnp.array(Y_outline) Xs = jnp.array(Xs) Ys = jnp.array(Ys) Reflection = jnp.array(Reflection) Emission = jnp.array(Emission) # behind_star = ((Xs ** 2 + Ys ** 2) < 1) # Reflection = jnp.where(Reflection == 0, jnp.nan, Reflection) # Emission = jnp.where(Emission == 0, jnp.nan, Emission) self._state = original_state self._coeffs_3d = original_3d_coeffs self._coeffs_2d = original_2d_coeffs self._para_coeffs_2d = original_para_coeffs_2d return { "orbit_positions": orbit_positions, "planet_x_outlines": X_outline, "planet_y_outlines": Y_outline, "sample_xs": Xs, "sample_ys": Ys, "reflected_intensity": Reflection, "emitted_intensity": Emission, }
[docs] def illustrate( self, times=None, true_anomalies=None, orbit=True, reflected=False, emitted=False, star_fill=True, window_size=0.4, star_centered=False, nsamples=50_000, figsize=(8, 8), ): """Visualize the layout of the system at one or more times. This method, if run in a jupyter notebook, will display a plot of some combination of the star, planet, and its orbit. It can color in the planet according to its reflected or emission profile, and the star according to its limb darkening profile. Helpful for checking the orientation of planet hotspots and/or its orientation after deformation. Args: times (array-like, [Days], default=None): The times at which to illustrate the system. The gap between times is assumed to be in units of days, but any zero-point/standard system (e.g. BJD) will work. Provide either this or ``true_anomalies`` but not both. true_anomalies (array-like, [Radian], default=None): The true anomalies at which to illustrate the system. Provide either this or ``times`` but not both. orbit (bool, default=True): Whether to plot a trace of the planet's orbital path reflected (bool, default=False): Whether to color in the planet according to its reflected flux profile. Can optionally include this or ``emitted`` but not both. emitted (bool, default=False): Whether to color in the planet according to its emitted flux profile. Can optionally include this or ``reflected`` but not both. star_fill (bool, default=True): Whether to color in the star according to its limb darkening profile. Note that the lowest color contour is bounded at zero, so if you have an unphysical limb darkening law where some radii are negative, those will appear as gaps (most often the contours will not reach the black outline of the star, which is always drawn). window_size (float, [Rstar], default=0.4): The size of the plotting window. The window will be centered on the mean position of the planet across all of the suggested times, unless ``star_centered`` is set to ``True``. star_centered (bool, default=False): Whether to center the plot on the star rather than the planet. nsamples (int, default=50_000): The number of random samples of the planet's surface to draw when illustrating the system. A larger number will increase the resolution of the plot but result in longer computation times. figsize (tuple, default=(8, 8)): The size of the figure to display. Passed directly to ``matplotlib.pyplot.subplots``. Returns: None: This method is used for its side effects of displaying a plot, not for its return value. """ if emitted: assert not reflected, "Can't illustrate both reflected and emitted flux" assert not self._state["parameterize_with_projected_ellipse"], ( "Can't illustrate emitted flux when only describing the 2D outline of" "the planet" ) if reflected: assert not emitted, "Can't illustrate both reflected and emitted flux" assert not self._state["parameterize_with_projected_ellipse"], ( "Can't illustrate reflected flux when only describing the 2D outline of" "the planet" ) fig, ax = plt.subplots(1, 1, figsize=(8, 8)) info = self._illustrate_helper( times=times, true_anomalies=true_anomalies, nsamples=nsamples ) if star_centered: im_center_x = 0 im_center_y = 0 else: im_center_x = jnp.mean(info["planet_x_outlines"]) im_center_y = jnp.mean(info["planet_y_outlines"]) star = plt.Circle((0, 0), 1, color="black", fill=False) ax.add_artist(star) if star_fill & (not jnp.all(self._state["ld_u_coeffs"] == 0)): # lifted from engine.polynomial_limb_darkened_transit u_coeffs = jnp.ones(self._state["ld_u_coeffs"].shape[0] + 1) * (-1) u_coeffs = u_coeffs.at[1:].set(self._state["ld_u_coeffs"]) g_coeffs = jnp.matmul(self._state["greens_basis_transform"], u_coeffs) normalization_constant = 1 / ( jnp.pi * (g_coeffs[0] + (2 / 3) * g_coeffs[1]) ) def _star_radial_profile(r): us = jnp.ones(self._state["ld_u_coeffs"].shape[0] + 1) * (-1) us = us.at[1:].set(self._state["ld_u_coeffs"]) mu = jnp.sqrt(1 - r**2) powers = jnp.arange(len(us)) return -jnp.sum(us * (1 - mu) ** powers) * normalization_constant X = jnp.linspace(-1, 1, 300) Y = jnp.linspace(-1, 1, 300) X, Y = jnp.meshgrid(X, Y) R = jnp.sqrt(X**2 + Y**2) Z = jax.vmap(_star_radial_profile)(R.flatten()).reshape(X.shape) min_val = jnp.max(jnp.array([0, jnp.nanmin(Z)])) max_val = jnp.nanmax(Z) ax.contourf( X, Y, Z, cmap="copper", levels=jnp.linspace(min_val, max_val, 20) ) elif star_fill: fill = plt.Circle((0, 0), 1, color="orange", fill=True, alpha=0.5) ax.add_artist(fill) if orbit: ax.plot( info["orbit_positions"][0, :], info["orbit_positions"][1, :], color="black", ls="--", lw=1, label="Orbit", ) for i in range(len(info["planet_x_outlines"])): ax.plot( info["planet_x_outlines"][i], info["planet_y_outlines"][i], color="black", lw=1, label="Planet", ) if reflected: ax.hexbin( info["sample_xs"][i], info["sample_ys"][i], info["reflected_intensity"][i], cmap="plasma", gridsize=100, mincnt=1, ) if emitted: ax.hexbin( info["sample_xs"][i], info["sample_ys"][i], info["emitted_intensity"][i], cmap="plasma", gridsize=100, mincnt=1, ) ax.set( aspect="equal", xlim=(im_center_x - window_size / 2, im_center_x + window_size / 2), ylim=(im_center_y - window_size / 2, im_center_y + window_size / 2), ) return
[docs] @staticmethod def fit_limb_darkening_profile(intensities, order=None, mus=None, rs=None): """Convert a stellar limb darkening profile to a polynomial representation. Given a set of stellar parameters, one can use a grid of stellar models to compute the limb darkening profile as a function of projected `r` or of `mu = sqrt(1 - r**2)`. These profiles are often then approximated with one of a few common limb darkening "laws", such as the quadratic or 4-parameter non-linear laws. Since `squishyplanet` only supports polynomial limb darkening profiles, but can support nearly arbitrary orders, we can approximate the profile with a polynomial. This is a convenience function for converting between a grid-derived profile and its best-fit polynomial representation in the correct basis for `squishyplanet`. Args: intensities (array-like): The relative intensities of the star at a given `mu` or `r`. order (int): The order of the polynomial to fit to the limb darkening profile. Note that in the `squishyplanet` basis, the polynomial is defined as `1 - sum_{i=1}^{order} u_i (1 - mu)^i`, so the number of coefficients is `order`, not `order+1`. mus (array-like, default=None): The `mu` values at which the intensities were computed. If `rs` is not provided, this is required. rs (array-like, default=None): The `r` values at which the intensities were computed. If `mus` is not provided, this is required. Returns: array-like: The coefficients of the polynomial representation of the limb darkening profile. These can then be used as the `ld_u_coeffs` parameter in `OblateSystem`. """ return _fit_limb_darkening_profile( intensities=intensities, order=order, mus=mus, rs=rs )
[docs] @staticmethod def limb_darkening_profile(ld_u_coeffs=None, r=None, mu=None): """Compute the limb darkening profile of the star at a given radius. Meant as a helper function for sanity checks and plotting, especially if you're using higher-order limb darkening laws and are concerned if the profile is positive/monotonic. Args: ld_u_coeffs (array-like, default=self.state["ld_u_coeffs"]): The coefficients of the polynomial limb darkening law. r (float or array-like, default=None): The radius at which to compute the limb darkening profile. Must be between 0 and 1. If provided, ``mu`` should be ``None``. mu (float or array-like, default=None): The cosine of the angle between the line of sight and the normal to the surface of the star. Must be between 0 and 1. If provided, ``r`` should be ``None``. Returns: Array: The limb darkening profile of the star at the given r or mu values. """ assert (mu is None) != (r is None), "Only one of `mu` or `r` should be provided" greens_transform = generate_change_of_basis_matrix(len(ld_u_coeffs)) if r is None: r = jnp.sqrt(1 - mu**2) u_coeffs = jnp.ones(ld_u_coeffs.shape[0] + 1) * (-1) u_coeffs = u_coeffs.at[1:].set(ld_u_coeffs) g_coeffs = jnp.matmul(greens_transform, u_coeffs) # total flux from the star. 1/eq. 28 in Agol, Luger, and Foreman-Mackey 2020 normalization_constant = 1 / (jnp.pi * (g_coeffs[0] + (2 / 3) * g_coeffs[1])) def inner(r): us = jnp.ones(ld_u_coeffs.shape[0] + 1) * (-1) us = us.at[1:].set(ld_u_coeffs) mu = jnp.sqrt(1 - r**2) powers = jnp.arange(len(us)) return -jnp.sum(us * (1 - mu) ** powers) * normalization_constant if isinstance(r, float | int): return inner(r) else: return jax.vmap(inner)(r)
[docs] def lightcurve(self, params={}): """Compute the light curve of the system. This method will return the light curve of the system at the times specified when the system was initialized. If you want to compute the light curve at different times, or with different orbital parameters, you can pass those parameters as a dictionary to this method. The first time this is run for a given system, JAX will jit-compile the function, which can take some time. Subsequent calls will be much faster unless you change the shape of any of the input arrays (e.g., changing the number of times or the order of the polynomial limb darkening law). In those cases, or if changing any of boolean flags, JAX will need to re-compile the function again. Args: params (dict, default={}): A dictionary of parameters to update in the system state. Any keys not provided will be pulled from the current state of the system. Returns: Array: The timeseries lightcurve of the system. The length will be equal to `state["times"]`, and each index corresponds to a time in that array. Examples: >>> state = { "t_peri" : 0.0, "times" : jnp.linspace(-jnp.pi, 2*jnp.pi, 3504), "a" : 2.0, "period" : 2*jnp.pi, "r" : 0.1, "compute_reflected_phase_curve" : True, "compute_emitted_phase_curve" : True, "emitted_scale" : 1e-5, } >>> system = OblateSystem(**state) >>> system.lightcurve() """ return self._lightcurve_fwd_grad_enforced(params)
[docs] def loglike(self, params={}): """Compute the log likelihood of the system given the observed data and some set of parameters. This method will call :func:`lightcurve` with the provided parameters and compare the output to the observed data. The likelihood is assumed to be Gaussian with no correlation between times. The jitter term is added in quadrature to the provided uncertainties. Args: params (dict, default={}): A dictionary of parameters to update in the system state. Any keys not provided will be pulled from the current state of the system. Returns: float: The log likelihood of the system given the observed data and the provided parameters. """ return self._loglike_fwd_grad_enforced(params)
@partial( jax.jit, static_argnums=( 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, ), ) def _lightcurve( params, tidally_locked, compute_reflected_phase_curve, compute_emitted_phase_curve, compute_stellar_ellipsoidal_variations, compute_stellar_doppler_variations, parameterize_with_projected_ellipse, oversample, random_seed, phase_curve_nsamples, extended_illumination_npts, state, ): # always compute the primary transit and trend for key in params: state[key] = params[key] transit = poly_lightcurve(state, parameterize_with_projected_ellipse) trend = jnp.polyval(state["systematic_trend_coeffs"], state["times"]) # if you don't want any phase curve stuff, you're done if (not compute_reflected_phase_curve) & (not compute_emitted_phase_curve) and ( not compute_stellar_doppler_variations ) & (not compute_stellar_ellipsoidal_variations): oversampled_curve = transit + trend if oversample > 1: c = ( oversampled_curve.reshape(-1, oversample) * state["stencil"][None, :] ).sum(axis=1) else: c = oversampled_curve return c ###################################################### # compute the planet's contribution to the phase curve ###################################################### # generate the radii and thetas that you'll reuse at each timestep sample_radii, sample_thetas = generate_sample_radii_thetas( jax.random.key(random_seed), jnp.arange(phase_curve_nsamples), ) # solve Kepler's equation to get the true anomalies time_deltas = state["times"] - state["t_peri"] mean_anomalies = 2 * jnp.pi * time_deltas / state["period"] true_anomalies = kepler(mean_anomalies, state["e"]) state["f"] = true_anomalies if tidally_locked: state["prec"] = state["f"] # technically these are all calculated in "transit", but since phase # curves are a) rare and b) expensive, we'll just do it again to keep # the transit section of the code more self-contained three = planet_3d_coeffs(**state) two = planet_2d_coeffs(**three) positions = skypos(**state) x_c = positions[0, :] y_c = positions[1, :] z_c = positions[2, :] # just the reflected component if compute_reflected_phase_curve & (not compute_emitted_phase_curve): if extended_illumination_npts == 1: reflected = reflected_phase_curve( sample_radii, sample_thetas, two, three, state, x_c, y_c, z_c ) else: offsets = extended_illumination_offsets(**state) three = planet_3d_coeffs_extended_illumination(**state, offsets=offsets) two = planet_2d_coeffs(**three) reflected = extended_illumination_reflected_phase_curve( sample_radii, sample_thetas, two, three, state, x_c, y_c, z_c, offsets ) emitted = 0.0 # just the emitted component elif (not compute_reflected_phase_curve) & compute_emitted_phase_curve: reflected = 0.0 emitted = emission_phase_curve(sample_radii, sample_thetas, two, three, state) # both reflected and emitted components. this function shares some of the # computation between the two, so it's a bit faster than running them separately elif ( compute_reflected_phase_curve & compute_emitted_phase_curve & (extended_illumination_npts == 1) ): reflected, emitted = phase_curve( sample_radii, sample_thetas, two, three, state, x_c, y_c, z_c ) elif ( compute_reflected_phase_curve & compute_emitted_phase_curve & (extended_illumination_npts != 1) ): if extended_illumination_npts == 1: reflected = reflected_phase_curve( sample_radii, sample_thetas, two, three, state, x_c, y_c, z_c ) else: offsets = extended_illumination_offsets(**state) three = planet_3d_coeffs_extended_illumination(**state, offsets=offsets) two = planet_2d_coeffs(**three) reflected = extended_illumination_reflected_phase_curve( sample_radii, sample_thetas, two, three, state, x_c, y_c, z_c, offsets ) emitted = emission_phase_curve(sample_radii, sample_thetas, two, three, state) else: reflected = 0.0 emitted = 0.0 #################################################### # compute the star's contribution to the phase curve #################################################### if compute_stellar_ellipsoidal_variations: ellipsoidal = stellar_ellipsoidal_variations( state["f"], state["stellar_ellipsoidal_alpha"], state["period"] ) else: ellipsoidal = 0.0 if compute_stellar_doppler_variations: doppler = stellar_doppler_variations( state["f"], state["stellar_doppler_alpha"], state["period"] ) else: doppler = 0.0 #################################################### # put it all together #################################################### oversampled_curve = transit + trend + reflected + emitted + ellipsoidal + doppler if oversample > 1: c = (oversampled_curve.reshape(-1, oversample) * state["stencil"][None, :]).sum( axis=1 ) else: c = oversampled_curve return c @partial( jax.jit, static_argnums=( 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, ), ) def _loglike( params, tidally_locked, compute_reflected_phase_curve, compute_emitted_phase_curve, compute_stellar_ellipsoidal_variations, compute_stellar_doppler_variations, parameterize_with_projected_ellipse, oversample, random_seed, phase_curve_nsamples, extended_illumination_npts, state, ): lc = _lightcurve( params, tidally_locked, compute_reflected_phase_curve, compute_emitted_phase_curve, compute_stellar_ellipsoidal_variations, compute_stellar_doppler_variations, parameterize_with_projected_ellipse, oversample, random_seed, phase_curve_nsamples, extended_illumination_npts, state, ) for key in params: state[key] = params[key] resids = state["data"] - lc var = jnp.exp(state["log_jitter"]) ** 2 + state["uncertainties"] ** 2 return jnp.sum(-0.5 * (resids**2 / var + jnp.log(2 * jnp.pi * var))) @partial(jax.jit, static_argnums=(1,)) def _fit_limb_darkening_profile(intensities, order=None, mus=None, rs=None): """Convert a stellar limb darkening profile to a polynomial representation. Given a set of stellar parameters, one can use a grid of stellar models to compute the limb darkening profile as a function of projected `r` or of `mu = sqrt(1 - r**2)`. These profiles are often then approximated with one of a few common limb darkening "laws", such as the quadratic or 4-parameter non-linear laws. Since `squishyplanet` only supports polynomial limb darkening profiles, but can support nearly arbitrary orders, we can approximate the profile with a polynomial. This is a convenience function for converting between a grid-derived profile and its best-fit polynomial representation in the correct basis for `squishyplanet`. Args: intensities (array-like): The relative intensities of the star at a given `mu` or `r`. order (int): The order of the polynomial to fit to the limb darkening profile. Note that in the `squishyplanet` basis, the polynomial is defined as `1 - sum_{i=1}^{order} u_i (1 - mu)^i`, so the number of coefficients is `order`, not `order+1`. mus (array-like, default=None): The `mu` values at which the intensities were computed. If `rs` is not provided, this is required. rs (array-like, default=None): The `r` values at which the intensities were computed. If `mus` is not provided, this is required. Returns: array-like: The coefficients of the polynomial representation of the limb darkening profile. These can then be used as the `ld_u_coeffs` parameter in `OblateSystem`. """ if rs is not None: mus = jnp.sqrt(1 - rs**2) powers = jnp.arange(order + 1)[1:] a = ((1 - mus) ** powers[:, None]).T b = intensities - 1 ld_u_coeffs = -jnp.linalg.lstsq(a, b)[0] return ld_u_coeffs