Source code for parametric_ellipse

import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp


[docs] @jax.jit def poly_to_parametric_helper(rho_xx, rho_xy, rho_x0, rho_yy, rho_y0, rho_00): """A helper function for :func:`poly_to_parametric`. Args: rho_xx (Array [Dimensionless]): Coefficient of x^2 rho_xy (Array [Dimensionless]): Coefficient of xy rho_x0 (Array [Dimensionless]): Coefficient of x rho_yy (Array [Dimensionless]): Coefficient of y^2 rho_y0 (Array [Dimensionless]): Coefficient of y rho_00 (Array [Dimensionless]): Constant term Returns: Tuple: - r1 (Array [Rstar]): Semi-major axis of the projected ellipse - r2 (Array [Rstar]): Semi-minor axis of the projected ellipse - xc (Array [Rstar]): x-coordinate of the center of the ellipse - yc (Array [Rstar]): y-coordinate of the center of the ellipse - cosa (Array [Dimensionless]): Cosine of the rotation angle - sina (Array [Dimensionless]): Sine of the rotation angle """ # (* base eq *) # pxx x^2 + pxy x y + px0 x + pyy y^2 + py0 y + p00 == 1 # (* normalize to get rid of p0 *) # pxx/(1 - p00) x^2 + pxy /(1 - p00) x y + px0/(1 - p00) x + # pyy/(1 - p00) y^2 + py0 /(1 - p00) y == 1 # (* solve for the ellipse center *) # CoefficientRules[ # pxx/(1 - p00) x^2 + pxy /(1 - p00) x y + px0/(1 - p00) x + # pyy/(1 - p00) y^2 + py0 /(1 - p00) y /. {x -> x - xc, # y -> y - yc}, {x, y}] # Solve[{px0/(1 - p00) - (2 pxx xc)/(1 - p00) - (pxy yc)/(1 - p00) == 0, # py0/(1 - p00) - (pxy xc)/(1 - p00) - (2 pyy yc)/(1 - p00) == # 0 }, {xc, yc}] # (* plug back in *) # Simplify[ # CoefficientRules[ # pxx/(1 - p00) x^2 + pxy /(1 - p00) x y + px0/(1 - p00) x + # pyy/(1 - p00) y^2 + py0 /(1 - p00) y /. {x -> x - xc, # y -> y - yc} /. {xc -> -((-pxy py0 + 2 px0 pyy)/( # pxy^2 - 4 pxx pyy)), # yc -> -((-px0 pxy + 2 pxx py0)/(pxy^2 - 4 pxx pyy))}, {x, y}]] # (* normalize again to get the final coeffs *) # pxxShift = # Simplify[(pxx/( # 1 - p00)) /(1 - ( # px0 pxy py0 - pxx py0^2 - # px0^2 pyy)/((-1 + p00) (pxy^2 - 4 pxx pyy)))] # pxyShift = # Simplify[(pxy/( # 1 - p00))/(1 - ( # px0 pxy py0 - pxx py0^2 - # px0^2 pyy)/((-1 + p00) (pxy^2 - 4 pxx pyy)))] # pyyShift = # Simplify[(pyy/( # 1 - p00)) /(1 - ( # px0 pxy py0 - pxx py0^2 - # px0^2 pyy)/((-1 + p00) (pxy^2 - 4 pxx pyy)))] # the center of the ellipse xc = (rho_xy * rho_y0 - 2 * rho_yy * rho_x0) / (4 * rho_xx * rho_yy - rho_xy**2) yc = (rho_xy * rho_x0 - 2 * rho_xx * rho_y0) / (4 * rho_xx * rho_yy - rho_xy**2) # get new coefficients for the centered ellipse: all others are zero now, # explicitly got rid of rho_00 so there's a lot more divison rho_xx_shift = -( (rho_xx * (rho_xy**2 - 4 * rho_xx * rho_yy)) / ( (-1 + rho_00) * rho_xy**2 - rho_x0 * rho_xy * rho_y0 + rho_x0**2 * rho_yy + rho_xx * (rho_y0**2 + 4 * rho_yy - 4 * rho_00 * rho_yy) ) ) rho_xy_shift = (-(rho_xy**3) + 4 * rho_xx * rho_xy * rho_yy) / ( (-1 + rho_00) * rho_xy**2 - rho_x0 * rho_xy * rho_y0 + rho_x0**2 * rho_yy + rho_xx * (rho_y0**2 + 4 * rho_yy - 4 * rho_00 * rho_yy) ) rho_yy_shift = -( (rho_yy * (rho_xy**2 - 4 * rho_xx * rho_yy)) / ( (-1 + rho_00) * rho_xy**2 - rho_x0 * rho_xy * rho_y0 + rho_x0**2 * rho_yy + rho_xx * (rho_y0**2 + 4 * rho_yy - 4 * rho_00 * rho_yy) ) ) # get the rotation angle (edge case gives you nans if there's no rotation) theta = jnp.where( rho_xx_shift - rho_yy_shift != 0.0, 0.5 * jnp.arctan2(rho_xy_shift, (rho_xx_shift - rho_yy_shift)) + jnp.pi / 2, 0.0, ) theta = jnp.where(theta < 0.0, theta + jnp.pi, theta) # jax.debug.print("{x}", x=theta) cosa = jnp.cos(theta) sina = jnp.sin(theta) # get the semi-major and semi-minor axes a = ( rho_xx_shift * jnp.cos(theta) ** 2 + rho_xy_shift * jnp.cos(theta) * jnp.sin(theta) + rho_yy_shift * jnp.sin(theta) ** 2 ) b = ( rho_xx_shift * jnp.sin(theta) ** 2 - rho_xy_shift * jnp.cos(theta) * jnp.sin(theta) + rho_yy_shift * jnp.cos(theta) ** 2 ) r1 = 1 / jnp.sqrt(a) r2 = 1 / jnp.sqrt(b) return r1, r2, xc, yc, cosa, sina
[docs] @jax.jit def poly_to_parametric(rho_xx, rho_xy, rho_x0, rho_yy, rho_y0, rho_00): """Convert between the coefficients that describe an implicit to those defining a parametric ellipse. The input coefficients are those of the implicit ellipse equation: .. math:: \\rho_{xx} x^2 + \\rho_{xy} xy + \\rho_{x0} x + \\rho_{yy} y^2 + \\rho_{y0} y + \\rho_{00} = 1 Args: rho_xx (Array [Dimensionless]): Coefficient of x^2 rho_xy (Array [Dimensionless]): Coefficient of xy rho_x0 (Array [Dimensionless]): Coefficient of x rho_yy (Array [Dimensionless]): Coefficient of y^2 rho_y0 (Array [Dimensionless]): Coefficient of y rho_00 (Array [Dimensionless]): Constant term Returns: dict: Dictionary of coefficients for the parametric ellipse. The ellipse can now be described by the following parametric equations for parameter :math:`\\alpha`: .. math:: x = c_{x1} * \\cos(\\alpha) + c_{x2} * \\sin(\\alpha) + c_{x3} y = c_{y1} * \\cos(\\alpha) + c_{y2} * \\sin(\\alpha) + c_{y3} """ r1, r2, xc, yc, cosa, sina = poly_to_parametric_helper( rho_xx, rho_xy, rho_x0, rho_yy, rho_y0, rho_00 ) return { "c_x1": r1 * cosa, "c_x2": -r2 * sina, "c_x3": xc, "c_y1": r1 * sina, "c_y2": r2 * cosa, "c_y3": yc, }
[docs] @jax.jit def cartesian_intersection_to_parametric_angle( xs, ys, c_x1, c_x2, c_x3, c_y1, c_y2, c_y3, ): """Given a set of x and y coordinates corresponding to the intersection of the planet and star, compute the angle :math:`\\alpha` that corresponds to each point. Here, :math:`\\alpha` is the parameter in the parametric equations of the ellipse. See :func:`poly_to_parametric` for more details. Args: xs (Array [Rstar]): x-coordinates of the intersection points ys (Array [Rstar]): y-coordinates of the intersection points c_x1 (Array [Dimensionless]): Coefficient of x^2 c_x2 (Array [Dimensionless]): Coefficient of xy c_x3 (Array [Dimensionless]): Coefficient of x c_y1 (Array [Dimensionless]): Coefficient of y^2 c_y2 (Array [Dimensionless]): Coefficient of y c_y3 (Array [Dimensionless]): Constant term Returns: Array [Rstar]: The angle :math:`\\alpha` corresponding to each intersection point """ # center the ellipse xs -= c_x3 ys -= c_y3 # the x, y positions are now linear combinations of just cosa, sina # linear solve for those inv = jnp.linalg.inv(jnp.array([[c_x1, c_x2], [c_y1, c_y2]])) matrix = jax.vmap(lambda x, y: jnp.matmul(inv, jnp.array([x, y])))(xs, ys) cosa = matrix[:, 0] sina = matrix[:, 1] # convert to alpha alpha = jnp.arctan2(sina, cosa) return alpha