#!/usr/bin/python3

r'''Triangulate a feature in a pair of images to report a range

SYNOPSIS

  $ mrcal-triangulate
      --range-estimate 870
      left.cameramodel right.cameramodel
      left.jpg right.jpg
      1234 2234

  ## Feature [1234., 2234.] in the left image corresponds to [2917.9 1772.6] at 870.0m
  ## Feature match found at [2916.51891391 1771.86593517]
  ## q1 - q1_perfect_at_range = [-1.43873946 -0.76196924]
  ## Range: 677.699 m (error: -192.301 m)
  ## Reprojection error between triangulated point and q0: [5.62522473e-10 3.59364094e-09]pixels
  ## Observed-pixel sensitivity: 103.641m/pixel (q1). Worst direction: [1. 0.]. Linearized correction: 1.855 pixels
  ## Calibration yaw sensitivity: -3760.702m/deg. Linearized correction: -0.051 degrees of yaw
  ## Calibration pitch sensitivity: 0.059m/deg.
  ## Calibration translation sensitivity: 319.484m/m. Worst direction: [1 0 0].
  ## Linearized correction: 0.602 meters of translation
  ## Optimized yaw correction   = -0.03983 degrees
  ## Optimized pitch correction = 0.03255 degrees
  ## Optimized relative yaw (1 <- 0): -1.40137 degrees

Given a pair of images, a pair of camera models and a feature coordinate in the
first image, finds the corresponding feature in the second image, and reports
the range. This is similar to the stereo processing of a single pixel, but
reports much more diagnostic information than stereo tools do.

This is useful to evaluate ranging results.
'''

import sys
import argparse
import re
import os

def parse_args():
    r'''Parse the commandline, and return the resulting argparse namespace

    Reads sys.argv via argparse. Invalid arguments cause argparse to print a
    usage message and exit. Beyond what argparse validates on its own:

    - args.viz is '' (not None) if --viz was not given
    - --title and --extratitle are mutually exclusive: if both are given we
      print a message to stderr and exit with status 1

    '''

    def positive_float(string):
        r'''argparse "type" callable: accept only floating-point values > 0'''
        # Catch only the conversion failures. A bare "except" would also
        # swallow KeyboardInterrupt and SystemExit
        try:
            value = float(string)
        except (TypeError, ValueError):
            raise argparse.ArgumentTypeError("argument MUST be a positive floating-point number. Got '{}'".format(string))
        if value <= 0:
            raise argparse.ArgumentTypeError("argument MUST be a positive floating-point number. Got '{}'".format(string))
        return value
    def positive_int(string):
        r'''argparse "type" callable: accept only integer values > 0'''
        # int() already rejects anything that isn't an integer literal
        # ("1.5", "abc", ...), so no separate "is this integral?" check is
        # needed
        try:
            value = int(string)
        except (TypeError, ValueError):
            raise argparse.ArgumentTypeError("argument MUST be a positive integer. Got '{}'".format(string))
        if value <= 0:
            raise argparse.ArgumentTypeError("argument MUST be a positive integer. Got '{}'".format(string))
        return value

    parser = \
        argparse.ArgumentParser(description = __doc__,
                                formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument('--make-corrected-model1',
                        action='store_true',
                        help='''If given, we assume the --range-estimate is correct, and output to standard
                        output a rotated camera1 to produce this range''')
    parser.add_argument('--template-size',
                        type=positive_int,
                        nargs=2,
                        default = (13,13),
                        help='''The size of the template used for feature
                        matching, in pixel coordinates of the second image. Two
                        arguments are required: width height. This is passed
                        directly to mrcal.match_feature(). We default to
                        13x13''')
    parser.add_argument('--search-radius',
                        type=positive_int,
                        default = 20,
                        help='''How far the feature-matching routine should
                        search, in pixel coordinates of the second image. This
                        should be larger if the range estimate is poor,
                        especially, at near ranges. This is passed directly to
                        mrcal.match_feature(). We default to 20 pixels''')
    parser.add_argument('--range-estimate',
                        type=positive_float,
                        default = 50,
                        help='''Initial estimate of the range of the observed feature. This is used for the
                        initial guess in the feature-matching. If omitted, I
                        pick 50m, completely arbitrarily''')
    parser.add_argument('--plane-n',
                        type=float,
                        nargs=3,
                        help='''If given, assume that we're looking at a plane in space, and take this into
                        account when matching the template. The normal vector to
                        this plane is given here, in camera-0 coordinates. The
                        normal does not need to be normalized; any scaling is
                        compensated in planed. The plane is all points p such
                        that inner(p,planen) = planed''')
    parser.add_argument('--plane-d',
                        type=float,
                        help='''If given, assume that we're looking at a plane in space, and take this into
                        account when matching the template. The
                        distance-along-the-normal to the plane, in from-camera
                        coordinates is given here. The plane is all points p
                        such that inner(p,planen) = planed''')
    parser.add_argument('--q-calibration-stdev',
                        type    = float,
                        help='''The uncertainty of the point observations at
                        calibration time. If given, we report the
                        analytically-computed effects of this noise. If a value
                        <0 is passed-in, we infer the calibration-time noise
                        from the optimal calibration-time residuals''')
    parser.add_argument('--q-observation-stdev',
                        type    = float,
                        help='''The uncertainty of the point observations at
                        observation time. If given, we report the
                        analytically-computed effects of this noise.''')
    parser.add_argument('--q-observation-stdev-correlation',
                        type    = float,
                        default = 0.0,
                        help='''By default, the noise in the observation-time
                        pixel observations is assumed independent. This isn't
                        entirely realistic: observations of the same feature in
                        multiple cameras originate from an imager correlation
                        operation, so they will have some amount of correlation.
                        If given, this argument specifies how much correlation.
                        This is a value in [0,1] scaling the stdev. 0 means
                        "independent" (the default). 1.0 means "100%%
                        correlated".''')
    parser.add_argument('--stabilize-coords',
                        action = 'store_true',
                        help='''If propagating calibration-time noise
                        (--q-calibration-stdev != 0), we report the uncertainty
                        ellipse in a stabilized coordinate system, compensating
                        for the camera-0 coord system motion''')
    parser.add_argument('--method',
                        choices=( 'geometric',
                                  'lindstrom',
                                  'leecivera-l1',
                                  'leecivera-linf',
                                  'leecivera-mid2',
                                  'leecivera-wmid2' ),
                        default = 'leecivera-mid2',
                        help='''The triangulation method. By default we use the
                        "Mid2" method from Lee-Civera's paper''')
    parser.add_argument('--corr-floor',
                        type=float,
                        default=0.9,
                        help='''This is used to reject mrcal.match_feature() results. The default is 0.9:
                        accept only very good matches. A lower threshold may
                        still result in usable matches, but do interactively
                        check the feature-matcher results by passing
                        "--viz match"''')
    parser.add_argument('--viz',
                        choices=('match', 'uncertainty'),
                        help='''If given, we visualize either the
                        feature-matcher results ("--viz match") or the
                        uncertainty ellipse(s) ("--viz uncertainty"). By
                        default, this produces an interactive gnuplot window.
                        The feature-match visualization shows 2 overlaid images:
                        the larger image being searched and the transformed
                        template, placed at its best-fitting location. Each
                        individual image can be hidden/shown by clicking on its
                        legend in the top-right of the plot. It's generally most
                        useful to show/hide the template to visually verify the
                        resulting alignment.''')
    parser.add_argument('--title',
                        type=str,
                        default = None,
                        help='''Title string for the --viz plot. Overrides the
                        default title. Exclusive with --extratitle''')
    parser.add_argument('--extratitle',
                        type=str,
                        default = None,
                        help='''Additional string for the --viz plot to append
                        to the default title. Exclusive with --title''')

    parser.add_argument('--hardcopy',
                        type=str,
                        help='''Write the --viz output to disk, instead of an interactive plot''')
    parser.add_argument('--terminal',
                        type=str,
                        help=r'''The gnuplotlib terminal used in --viz. The default is good almost always, so most people don't
                        need this option''')
    parser.add_argument('--set',
                        type=str,
                        action='append',
                        help='''Extra 'set' directives to gnuplotlib for --viz. Can be given multiple times''')
    parser.add_argument('--unset',
                        type=str,
                        action='append',
                        help='''Extra 'unset' directives to gnuplotlib --viz. Can be given multiple times''')
    parser.add_argument('--clahe',
                        action='store_true',
                        help='''If given, apply CLAHE equalization to the images
                        prior to the matching''')
    parser.add_argument('models',
                        type=str,
                        nargs = 2,
                        help='''Camera models for the images. Both intrinsics and extrinsics are used''')
    parser.add_argument('images_and_features',
                        type=str,
                        nargs=4,
                        help='''The images and/or feature pixel coordinates to
                        use for the triangulation. This is either IMAGE0 IMAGE1
                        FEATURE0X FEATURE0Y or FEATURE0X FEATURE0Y FEATURE1X
                        FEATURE1Y. If images are given, the given pixel is a
                        feature in image0, and we search for the corresponding
                        feature in image1.''')

    args = parser.parse_args()

    # The rest of the script compares args.viz against strings, so "not given"
    # is represented as the empty string
    if args.viz is None:
        args.viz = ''

    if args.title      is not None and \
       args.extratitle is not None:
        print("--title and --extratitle are exclusive", file=sys.stderr)
        sys.exit(1)

    return args

# Parse the commandline before anything else. The heavier imports further down
# are deliberately deferred until after this, so that --help works without them
args = parse_args()

def float_or_none(s):
    r'''Convert s to a float, or return None if s doesn't describe a number

    This is used to tell apart the numeric positional arguments (pixel
    coordinates) from the non-numeric ones (image filenames).

    '''
    # Only the conversion failures mean "not a number". A bare "except" would
    # also swallow KeyboardInterrupt and SystemExit
    try:
        return float(s)
    except (TypeError, ValueError):
        return None
# Each positional argument becomes either a float (a pixel coordinate) or None
# (anything non-numeric, i.e. an image filename)
features = list(map(float_or_none, args.images_and_features))


# The plane is described by --plane-n and --plane-d together: having exactly
# one of the two is an error
if (args.plane_n is None) != (args.plane_d is None):
    print("--plane-n and --plane-d should both be given or neither should be", file=sys.stderr)
    sys.exit(1)

# arg-parsing is done before the imports so that --help works without building
# stuff, so that I can generate the manpages and README





import numpy as np

# features[] contains a float for each numeric positional argument and None for
# everything else. The pattern of numeric arguments selects the usage mode
is_number = [isinstance(x, float) for x in features]

if all(is_number):
    # FEATURE0X FEATURE0Y FEATURE1X FEATURE1Y: both observations are given, so
    # there are no images to search
    q0, q1 = (np.array(xy, dtype=float)
              for xy in (features[:2], features[2:]))
    images = None
elif all(is_number[2:]):
    # IMAGE0 IMAGE1 FEATURE0X FEATURE0Y: q1 will come from the feature matcher
    q0     = np.array(features[2:], dtype=float)
    q1     = None
    images = args.images_and_features[:2]
else:
    print(f"Need exactly 4 positional arguments: either IMAGE0 IMAGE1 FEATURE0X FEATURE0Y or FEATURE0X FEATURE0Y FEATURE1X FEATURE1Y",
          file=sys.stderr)
    sys.exit(1)


import numpysane as nps

import mrcal
import mrcal.utils
import scipy.optimize
import time


# --method value -> name of the implementing function in the mrcal module. Only
# the function for the requested method is looked up
triangulation_function_names = \
    { 'geometric':       'triangulate_geometric',
      'lindstrom':       'triangulate_lindstrom',
      'leecivera-l1':    'triangulate_leecivera_l1',
      'leecivera-linf':  'triangulate_leecivera_linf',
      'leecivera-mid2':  'triangulate_leecivera_mid2',
      'leecivera-wmid2': 'triangulate_leecivera_wmid2' }

if args.method not in triangulation_function_names:
    # argparse restricts --method to the known choices
    raise Exception("Getting here is a bug")

method = getattr(mrcal, triangulation_function_names[args.method])


def compute_H10_planes(models, q0, v0, plane_n, plane_d):

    r'''Report a homography for two cameras observing a point on a plane

    ARGUMENTS

    - models: the two camera models. models[0] and models[1] are used; their
      intrinsics() and extrinsics_Rt_...() methods are called

    - q0: the pixel coordinate of the feature in camera 0

    - v0: the observation vector in camera-0 coordinates corresponding to q0.
      v0 is assumed normalized

    - plane_n, plane_d: the plane being observed, in camera-0 coordinates. The
      plane is all points p such that inner(p,plane_n) = plane_d

    RETURNED VALUE

    H10: a (3,3) array. The caller applies this to q0 with
    mrcal.apply_homography() to get the corresponding pixel in camera 1. The
    lens behavior is linearized at the observed point, so this is a local
    approximation for non-pinhole lenses.

    NOTE: everything after the first "return H10" is an alternative
    implementation kept for reference. It is unreachable. The nested
    get_R_abn() helper is used only by that unreachable code

    '''

    def get_R_abn(n):
        r'''Return a valid rotation matrix with the given n as the last column

        n is assumed to be a normal vector: nps.norm2(n) = 1

        Returns a rotation matrix where the 3rd column is the given vector n. The
        first two vectors are arbitrary, but are guaranteed to produce a valid
        rotation: RtR = I, det(R) = 1

        '''

        # arbitrary, but can't be exactly n
        a = np.array((1., 0, 0, ))
        proj = nps.inner(a, n)
        if abs(proj) > 0.8:
            # Any other orthogonal vector will do. If this projection was >
            # 0.8, I'm guaranteed that all others will be smaller
            a = np.array((0, 1., 0, ))
            proj = nps.inner(a, n)

        # Gram-Schmidt: remove the component of a along n, and normalize
        a -= proj*n
        a /= nps.mag(a)
        # b completes the basis
        b = np.cross(n,a)
        # cat() stacks a,b,n as rows; the transpose makes them columns
        return nps.transpose(nps.cat(a,b,n))




    # The transformation taking points in camera-0 coordinates to camera-1
    # coordinates: cam1 <- ref <- cam0
    Rt10 = mrcal.compose_Rt( models[1].extrinsics_Rt_fromref(),
                             models[0].extrinsics_Rt_toref() )
    # An Rt transformation is a (4,3) array: the rotation matrix in the first 3
    # rows, and the translation in the last row
    R10 = Rt10[:3,:]
    t10 = Rt10[ 3,:]

    # The homography definition. Derived in many places. For instance in
    # "Motion and structure from motion in a piecewise planar environment"
    # by Olivier Faugeras, F. Lustman.
    #
    # This H10 operates on observation vectors, not on pixels. The lens
    # behavior is applied further down
    H10 = plane_d * R10 + nps.outer(t10, plane_n)

    # Find the 3D point on the plane, and translate the homography to match them
    # up
    # inner(k*v0, n) = d -> k = d / inner(v0,n)
    k = plane_d / nps.inner(v0, plane_n)
    p0 = k * v0
    p1 = mrcal.transform_point_Rt(Rt10, p0)

    # I now compute the "projection matrix" for the two cameras at p. Such a
    # matrix can only fully describe a pinhole lens, but I can compute the local
    # estimate from the gradient. For an observation vector v = (vx vy 1) I want
    # q = (K*v)[:2]/(K*v)[2]. Let
    #
    #   K = [ K00 K02 ]
    #       [ 0 0  1  ]
    #
    # Where K00 is a (2,2) matrix and K02 is a (2,1) matrix. So q = K00 vxy
    # + K02. K00 = dq/dvxy
    K0 = np.eye(3, dtype=float)
    _, dq0_dvxy0,_ = mrcal.project(p0/p0[2], *models[0].intrinsics(),
                                   get_gradients=True)
    K0[:2,:2] = dq0_dvxy0[:,:2]
    # K02 is chosen so that the linearized projection reproduces q0 exactly at
    # the point being observed
    K0[:2, 2] = q0 - nps.inner(K0[:2,:2], p0[:2]/p0[2])

    # Now the same thing for K1
    K1 = np.eye(3, dtype=float)
    q1,dq1_dvxy1,_ = mrcal.project(p1/p1[2], *models[1].intrinsics(),
                                   get_gradients=True)
    K1[:2,:2] = dq1_dvxy1[:,:2]
    K1[:2, 2] = q1 - nps.inner(K1[:2,:2], p1[:2]/p1[2])

    # pixels in camera 0 -> observation vectors in camera 0 -> observation
    # vectors in camera 1 -> pixels in camera 1
    H10 = nps.matmult( K1, H10, np.linalg.inv(K0) )

    return H10

    # ALTERNATIVE IMPLEMENTATION I MIGHT WANT LATER FOR OTHER PURPOSES: affine
    # transformation that has no projective behavior
    #
    # THE CODE BELOW IS UNREACHABLE: the function has already returned above

    v0 = mrcal.unproject(q0, *models[0].intrinsics())

    # k*v0 is on the plane: inner(k*v0, n) = d -> k = d/inner(v0,n)
    p0 = v0 * (plane_d / nps.inner(v0, plane_n))

    # The coordinate in the plane are (u,v,n). (u,v) are in the plane, n is the
    # normal. p0 = R_cam0_uvn * p_uvn -> dp0/dp_uvn = R_cam0_uvn, dp0/dp_uv =
    # R_cam0_uv
    R_cam0_uv = get_R_abn(plane_n)[:,:2]

    _, dq0_dp0,_ = mrcal.project(p0, *models[0].intrinsics(),
                                 get_gradients=True)
    dq0_duv = nps.matmult(dq0_dp0, R_cam0_uv)

    # For the same (u,v,n) coord system we have p1 = R10 R_cam0_uvn * p_uvn +
    # t10 -> dp1/dp_uvn = R10 R_cam0_uvn, dp1/dp_uv = R10 R_cam0_uv
    p1 = mrcal.transform_point_Rt(Rt10, p0)
    q1,dq1_dp1,_ = mrcal.project(p1, *models[1].intrinsics(),
                                 get_gradients=True)
    dq1_duv = nps.matmult(dq1_dp1, Rt10[:3,:], R_cam0_uv)

    H10 = np.eye(3, dtype=float)
    nps.matmult2( dq1_duv,
                  np.linalg.inv(dq0_duv),
                  out = H10[:2,:2])

    H10[:2,2] = q1 - mrcal.apply_homography(H10, q0)

    return H10

def visualize_match(args,
                    match_feature_diagnostics,
                    match_feature_out):
    r'''Plot the feature-matching result

    match_feature_out is what mrcal.match_feature(..., return_plot_args=True)
    returned: the plot data tuples and the plot options follow its first two
    elements. If match_feature_diagnostics is None, no matching was performed;
    we then print a warning, and plot nothing.

    '''

    import gnuplotlib as gp

    if match_feature_diagnostics is None:
        print("## WARNING: no feature matching was performed, so not visualizing the feature-matching results")
        return

    plot_pieces = match_feature_out[2:]
    data_tuples, plot_options = plot_pieces

    # The last plot data tuple is the correlation map: leave it out
    data_tuples_no_correlation = data_tuples[:-1]

    # An interactive plot should wait for the user; a hardcopy should not
    gp.plot( *data_tuples_no_correlation,
             **plot_options,
             wait = args.hardcopy is None)

def visualize_uncertainty(args,
                          Var_p_calibration,
                          Var_p_observation,
                          Var_p_joint,
                          Rt01, p0):
    r'''Plot the available uncertainty ellipses in the epipolar plane

    Each of Var_p_calibration, Var_p_observation, Var_p_joint is either None or
    the covariance of the triangulated point p0. Every covariance that's
    available is drawn as an ellipse centered on p0, in a 2D coordinate system
    lying in the epipolar plane. If none are available we print a warning, and
    plot nothing.

    '''

    import gnuplotlib as gp

    labelled_variances = \
        [ (what,var)
          for what,var in ( ("Calibration-time", Var_p_calibration),
                            ("Observation-time", Var_p_observation),
                            ("Joint",            Var_p_joint) )
          if var is not None ]

    if not labelled_variances:
        print("## WARNING: don't have any ellipses, so not visualizing them")
        return

    # I remap the point to the epipolar-plane-centered coords:
    # - u: along the baseline
    # - v: forward to the point
    # - n: normal
    translation01 = Rt01[3,:]
    axis_u = translation01 / nps.mag(translation01)
    axis_v = p0 - nps.inner(axis_u, p0)*axis_u
    axis_v = axis_v / nps.mag(axis_v)
    axis_n = np.cross(axis_u, axis_v)

    R_uvn_cam0 = nps.cat(axis_u, axis_v, axis_n)
    R_uv_cam0  = R_uvn_cam0[:2,:]
    p_uv       = mrcal.rotate_point_R(R_uvn_cam0, p0)[:2]

    # I want to plot the ellipse in uv coords. So I want to plot
    # Var(R_uvn_cam0*p_cam0) = R_uv_cam0 Var(p_cam0) nps.transpose(R_uv_cam0)
    R_cam0_uv = nps.transpose(R_uv_cam0)
    plotargs = \
        [ mrcal.utils._plot_arg_covariance_ellipse(p_uv,
                                                   nps.matmult( R_uv_cam0,
                                                                var,
                                                                R_cam0_uv ),
                                                   f"{what} noise")
          for what,var in labelled_variances ]

    # --title replaces the default title; --extratitle is appended
    title = "Uncertainty propagation" if args.title is None else args.title
    if args.extratitle is not None:
        title += f": {args.extratitle}"

    plot_options = dict(square = True,
                        xlabel = 'Epipolar plane x (along the baseline) (m)',
                        ylabel = 'Epipolar plane y (forward/back) (m)',
                        title  = title,
                        wait   = args.hardcopy is None)

    gp.add_plot_option(plot_options,
                       _set       = args.set,
                       unset      = args.unset,
                       hardcopy   = args.hardcopy,
                       terminal   = args.terminal,
                       overwrite  = True)

    gp.plot(*plotargs, **plot_options)



# Keep the printed arrays short and free of scientific notation
np.set_printoptions(suppress  = True,
                    precision = 3)


if images is not None:
    # We were given image filenames: replace them with the loaded images
    if args.clahe:
        import cv2
        clahe = cv2.createCLAHE()
        clahe.setClipLimit(8)

    def imread(filename):
        r'''Load an image as 8-bit single-channel, equalizing it if --clahe'''
        image = mrcal.load_image(filename, bits_per_pixel = 8, channels = 1)
        return clahe.apply(image) if args.clahe else image

    images = list(map(imread, images))

models = list(map(mrcal.cameramodel, args.models))

# Propagating calibration-time noise is only possible if both models come from
# the same calibration run, and haven't been moved since. Check that up front
propagating_calibration_noise = \
    args.q_calibration_stdev is not None and \
    args.q_calibration_stdev != 0

if propagating_calibration_noise:

    optimization_inputs = models[0].optimization_inputs()

    if optimization_inputs is None:
        print("optimization_inputs are not available, so I cannot propagate calibration-time noise",
              file=sys.stderr)
        sys.exit(1)

    if not models[0]._optimization_inputs_match(models[1]):
        print("The optimization_inputs for both models must be identical",
              file=sys.stderr)
        sys.exit(1)

    for i,model in enumerate(models):
        if not model._extrinsics_moved_since_calibration():
            continue
        print(f"The given models must have been fixed inside the initial calibration. Model {i} has been moved",
              file=sys.stderr)
        sys.exit(1)

# Record when and how this output was produced
timestamp  = time.strftime("%Y-%m-%d %H:%M:%S")
invocation = ' '.join(map(mrcal.shellquote, sys.argv))
print("## generated on {} with   {}\n".format(timestamp, invocation))

# The normalized observation vector of the feature, in camera-0 coordinates
v0 = mrcal.unproject(q0, *models[0].intrinsics(), normalize = True)

# cam0 <- ref <- cam1
Rt_cam0_ref = models[0].extrinsics_Rt_fromref()
Rt_ref_cam1 = models[1].extrinsics_Rt_toref()
Rt01 = mrcal.compose_Rt( Rt_cam0_ref, Rt_ref_cam1 )

if q1 is None:
    # Only q0 was given: look for the corresponding feature in image 1. The
    # matcher needs a homography describing how the template is expected to
    # look in image 1; I get that from a plane in space

    if args.plane_n is None:
        # no plane defined. So I make up a plane that's the given distance away,
        # viewed head-on.

        # I compute the "mean" forward view of the cameras

        # "forward" for cam0 in cam0 coords
        n0 = np.array((0,0,1.))
        # "forward" for cam1 in cam0 coords
        n1 = mrcal.rotate_point_R(Rt01[:3,:], np.array((0,0,1.)))
        # normalized "mean forward" in cam0 coords
        n = n0+n1
        n = n / nps.mag(n)

        # inner(n,k*v) = d
        d = nps.inner(n, args.range_estimate*v0)
        H10 = compute_H10_planes(models,
                                 q0, v0,
                                 n, d)
    else:
        # the user told us which plane we're looking at
        H10 = compute_H10_planes(models,
                                 q0, v0,
                                 np.array(args.plane_n),
                                 args.plane_d)

    q1_perfect = mrcal.apply_homography(H10, q0)
    print(f"## Feature {q0} in the left image corresponds to {q1_perfect} at {args.range_estimate}m")

    # Pass along only the title options that were actually given
    plotkwargs_extra = \
        { key: value
          for key,value in ( ('title',      args.title),
                             ('extratitle', args.extratitle) )
          if value is not None }

    match_feature_out = \
        mrcal.match_feature(*images,
                            q0               = q0,
                            H10              = H10,
                            search_radius1   = args.search_radius,
                            template_size1   = args.template_size,

                            visualize        = args.viz == 'match',
                            return_plot_args = True,
                            _set             = args.set,
                            unset            = args.unset,
                            hardcopy         = args.hardcopy,
                            terminal         = args.terminal,
                            **plotkwargs_extra)

    q1, match_feature_diagnostics = match_feature_out[:2]

    # Reject both outright failures and matches that correlate too poorly
    if q1 is None or \
       match_feature_diagnostics['matchoutput_optimum_subpixel'] < args.corr_floor:
        print("## Feature-matching failed! Maybe increase the --search-radius or reduce --corr-floor?")
        sys.exit(1)

    print(f"## Feature match found at {q1}")
    print(f"## q1 - q1_perfect_at_range = {q1 - q1_perfect}")

else:
    # Both observations were given on the commandline: nothing to match
    print(f"## Using user-supplied feature match at {q1}")
    match_feature_diagnostics = None

# Triangulate. mrcal.triangulate() returns the point by itself if no noise
# propagation was requested, or a tuple containing the point and the requested
# covariances otherwise. Whatever wasn't requested stays None
Var_p_calibration = None
Var_p_observation = None
Var_p_joint       = None

have_calibration_noise = args.q_calibration_stdev is not None
have_observation_noise = args.q_observation_stdev is not None

# Only the noise options that were given are passed on
triangulate_kwargs = dict( stabilize_coords = args.stabilize_coords,
                           method           = method )
if have_calibration_noise:
    triangulate_kwargs['q_calibration_stdev'] = args.q_calibration_stdev
if have_observation_noise:
    triangulate_kwargs['q_observation_stdev']             = args.q_observation_stdev
    triangulate_kwargs['q_observation_stdev_correlation'] = args.q_observation_stdev_correlation

triangulation_result = \
    mrcal.triangulate( nps.cat(q0, q1),
                       models,
                       **triangulate_kwargs )

if have_calibration_noise and have_observation_noise:
    p0,                \
    Var_p_calibration, \
    Var_p_observation, \
    Var_p_joint = triangulation_result
elif have_calibration_noise:
    p0, Var_p_calibration = triangulation_result
elif have_observation_noise:
    p0, Var_p_observation = triangulation_result
else:
    p0 = triangulation_result

# The triangulation routines report divergent rays as p = (0,0,0). No range can
# be reported in that case, so complain, show the feature-match visualization if
# it was asked for (that helps to diagnose false matches), and give up
if nps.norm2(p0) == 0:
    print("Observation rays divergent: feature is near infinity or a false match was found", file=sys.stderr)

    if args.viz == 'match':
        # If q1 was given on the commandline, no matching was performed, and
        # match_feature_out doesn't exist. visualize_match() looks at
        # match_feature_out only if match_feature_diagnostics is not None
        visualize_match(args,
                        match_feature_diagnostics,
                        None if match_feature_diagnostics is None else match_feature_out)
    elif args.viz == 'uncertainty':
        # A valid request, but there's no triangulated point to draw the
        # ellipses around
        print("## WARNING: the observation rays are divergent, so not visualizing the uncertainty")
    elif args.viz != '':
        raise Exception("Unknown --viz arg. This is a bug")
    sys.exit(1)

# The empirical sensitivities computed further on use the lower-level
# triangulation function directly. As a sanity check I make sure that it agrees
# with what mrcal.triangulate() just gave me
v1 = mrcal.unproject(q1, *models[1].intrinsics())

p0_lowlevel = method( v0, v1,
                      Rt01        = Rt01,
                      v_are_local = True)
if nps.norm2( p0 - p0_lowlevel ) > 1e-6:
    print("## WARNING: mrcal.triangulate() returned something different from the internal routine. This is a bug. The empirical sensitivities may be off")


# The translation part of Rt01: the baseline vector
baseline01 = Rt01[3,:]

# worst-case rotation axis will be the yaw. This axis is the normal to the plane
# containing the 2 cameras and the observed point
n_epipolar_plane = np.cross(p0, baseline01)
n_epipolar_plane /= nps.mag(n_epipolar_plane)
if n_epipolar_plane[1] < 0:
    # I want n_epipolar_plane to point in the same direction as the cam0 y
    # axis, since that's the nominal yaw axis
    n_epipolar_plane *= -1.

# I also want to look at the pitch: the rotation around the baseline vector
v_pitch = baseline01 / nps.mag(baseline01)


def get_rt10_perturbed( t01err            = np.array((0,0,0)),
                        r01err_yaw        = 0,
                        r01err_yaw_axis   = None,
                        r01err_pitch      = 0,
                        r01err_pitch_axis = None):
    r'''Return the rt10 transformation with a calibration error applied

    The nominal geometry is the global Rt01. The rotation is perturbed by
    r01err_yaw radians around r01err_yaw_axis and by r01err_pitch radians around
    r01err_pitch_axis; t01err is subtracted from the translation of the result.
    An axis is looked at only if the corresponding angle is nonzero. With the
    default arguments nothing is perturbed.

    '''

    def rotation_about(axis, angle):
        # The Rodrigues rotation formula
        K = mrcal.skew_symmetric(axis)
        return np.eye(3) + np.sin(angle) * K + (1. - np.cos(angle))*nps.matmult(K,K)

    # This is crude. I'm assuming that the two rotations are small, around
    # orthogonal axes, so they commute
    R = np.eye(3)
    if r01err_yaw:
        R = rotation_about(r01err_yaw_axis, r01err_yaw)
    if r01err_pitch:
        R = nps.matmult(R, rotation_about(r01err_pitch_axis, r01err_pitch))

    # The rotation-only perturbation, as an Rt transformation with a 0
    # translation
    Rt_rotation_err = nps.glue(R, np.zeros((3,)), axis=-2)

    Rt10 = mrcal.compose_Rt(mrcal.invert_Rt(Rt01),
                            Rt_rotation_err)
    rt10 = mrcal.rt_from_Rt(Rt10)
    rt10[3:] -= t01err
    return rt10

def get_world_intersection_point(qerr1 = np.array((0,0)),
                                 **kwargs):
    r'''Triangulate with a perturbed camera-1 observation and/or geometry

    qerr1 is added to the camera-1 pixel observation q1. All the other
    arguments are passed to get_rt10_perturbed() to perturb the relative camera
    geometry. Returns whatever the selected triangulation method returns.

    '''

    rt10_perturbed = get_rt10_perturbed(**kwargs)
    v1_perturbed   = mrcal.unproject(q1 + qerr1, *models[1].intrinsics())

    Rt01_perturbed = mrcal.invert_Rt( mrcal.Rt_from_rt(rt10_perturbed) )

    return method( v0, v1_perturbed,
                   Rt01        = Rt01_perturbed,
                   v_are_local = True)

def get_range(**kwargs):
    r'''Return the range to the point triangulated with the given perturbations

    The arguments are passed to get_world_intersection_point(). Divergent rays
    have p = (0,0,0), so they produce a range of 0.

    '''
    return nps.mag( get_world_intersection_point(**kwargs) )



def _range_or_exit(**kwargs):
    r'''get_range(), but give up on the whole program if the rays diverge'''
    r = get_range(**kwargs)
    if r == 0:
        print("Sampled rays divergent", file=sys.stderr)
        sys.exit(1)
    return r


range0 = get_range()

if range0 == 0:
    # The initial geometry produces a triangulation behind the camera. This is a
    # qualitatively different solution, so I don't report linear extrapolation results
    print("## Initial geometry is divergent. Not reporting sensitivities")

else:

    range_err_have = range0 - args.range_estimate

    print(f"## Triangulated point at {p0}; direction: {v0}")
    print("## Range: {:.2f} m (error: {:.2f} m)".format(range0, range_err_have))
    print(f"## q0 - q0_triangulation = {q0 - mrcal.project(p0, *models[0].intrinsics())}")


    for what,var in ( ("calibration-time", Var_p_calibration),
                      ("observation-time", Var_p_observation),
                      ("joint",            Var_p_joint)):

        if var is None:
            continue

        l,v = mrcal.utils._sorted_eig(var)
        # look at the least-confident direction: the last eigenvalue and its
        # eigenvector
        sigma = np.sqrt(l[-1])
        v     = v[:,-1]
        print(f"## Uncertainty propagation: {what} noise suggests worst confidence of sigma={sigma:.2f}m along {v}")



    # Crude sensitivity testing. Finite differences for each piece


    # Sensitivity to the pixel observation in camera 1. I sample directions in
    # the image, and report the worst one
    delta = 1e-3

    rangeerr_worst_qerr = 0
    for th in np.linspace(0,2.*np.pi, 90, endpoint=False):
        vq = np.array((np.cos(th),np.sin(th)))
        r = _range_or_exit(qerr1 = delta * vq)
        rangeerr = np.abs(r - range0)
        if rangeerr > rangeerr_worst_qerr:
            rangeerr_worst_qerr = rangeerr
            vq_worst_err        = vq
    print("## Observed-pixel sensitivity: {:.2f}m/pixel (q1). Worst direction: {}. Linearized correction: {:.2f} pixels".format(
              rangeerr_worst_qerr/delta,
              vq_worst_err,
              -range_err_have/rangeerr_worst_qerr*delta ))


    # Sensitivity to the relative rotation of the cameras. delta is in degrees
    # here. The worst-case rotation axis will be the yaw. This axis is the
    # normal to the plane containing the 2 cameras and the observed point
    delta = 1e-6

    r = _range_or_exit(r01err_yaw_axis = n_epipolar_plane,
                       r01err_yaw      = delta * np.pi/180.)
    rangeerr_yaw = r - range0
    print("## Calibration yaw (rotation in epipolar plane) sensitivity: {:.2f}m/deg. Linearized correction: {:.3f} degrees of yaw".format(
              rangeerr_yaw/delta,
              -range_err_have/rangeerr_yaw*delta))

    # The same thing, but rotating around the cam0 y axis
    r = _range_or_exit(r01err_yaw_axis = np.array((0., 1., 0.,)),
                       r01err_yaw      = delta * np.pi/180.)
    rangeerr_yaw = r - range0
    print("## Calibration yaw (cam0 y axis)                sensitivity: {:.2f}m/deg. Linearized correction: {:.3f} degrees of yaw".format(
              rangeerr_yaw/delta,
              -range_err_have/rangeerr_yaw*delta))

    # I also want to look at the pitch
    r = _range_or_exit(r01err_pitch_axis = v_pitch,
                       r01err_pitch      = delta * np.pi/180.)
    rangeerr_pitch = r - range0
    print("## Calibration pitch (tilt of epipolar plane) sensitivity: {:.2f}m/deg.".format(rangeerr_pitch/delta))



    # Sensitivity to the relative translation of the cameras. I sample
    # directions in space, and report the worst one
    delta = 1e-3
    rangeerr_worst_t01err = 0
    for th in np.linspace(0, 2.*np.pi, 40, endpoint=False):
        for ph in np.linspace(0, np.pi, 20):

            vt01 = np.array((np.cos(ph) * np.sin(th),
                             np.cos(ph) * np.cos(th),
                             np.sin(ph)))
            r = _range_or_exit(t01err = delta * vt01)
            rangeerr = np.abs(r - range0)
            if rangeerr > rangeerr_worst_t01err:
                rangeerr_worst_t01err = rangeerr
                vt01_worst_err        = vt01

    print("## Calibration translation sensitivity: {:.2f}m/m. Worst direction: {}. Linearized correction: {:.2f} meters of translation".format(
              rangeerr_worst_t01err/delta,
              vt01_worst_err,
              -range_err_have/rangeerr_worst_t01err*delta))


# I have a pixel coordinate given on the commandline and a range. I can use the
# range to correct the yaw (x shift in the image). And I can tweak the pitch to
# correct the y shift in the image.
def get_range_err_yaw_pitch_sq(dyawdpitch):
    r'''Cost function for the yaw,pitch correction search

    dyawdpitch contains (yaw, pitch) corrections in radians, applied around
    n_epipolar_plane and v_pitch respectively. Returns the squared range error
    (against --range-estimate) plus the squared norm of the scaled reprojection
    error in camera 0. Divergent geometry returns 1e6.

    '''

    yaw   = dyawdpitch[0]
    pitch = dyawdpitch[1]

    p = get_world_intersection_point(r01err_yaw_axis   = n_epipolar_plane,
                                     r01err_yaw        = yaw,
                                     r01err_pitch_axis = v_pitch,
                                     r01err_pitch      = pitch)

    range_have = nps.mag(p)
    if range_have == 0:
        # The intersection is behind the camera. I arbitrarily make the error
        # very large
        return 1e6

    erange = range_have - args.range_estimate

    # The range errors are generally much higher than pixel errors, so I scale
    # the reprojection errors up. Nothing about any of this is particularly
    # principled, but it is useful for testing
    ereprojection = (mrcal.project(p, *models[0].intrinsics()) - q0) * 100

    return erange*erange + nps.norm2(ereprojection)

# Choose the starting point for the iterative optimization. If the unperturbed
# geometry is already convergent, start there. Otherwise take discrete 1-degree
# steps away from (0,0) along +yaw, -yaw, +pitch, -pitch in turn, and start at
# the first point whose cost is below the 1e6 "divergent" value
if get_range_err_yaw_pitch_sq((0,0)) < 1e6:
    # not divergent initially
    dyawdpitch = (0,0)
else:
    print("## Initially divergent. Finding a usable operating point to begin optimization")

    step = 1. * np.pi/180.0
    search_directions = step * np.array((( 1.,  0.),
                                         (-1.,  0.),
                                         ( 0.,  1.),
                                         ( 0., -1.)))

    for d in search_directions:
        for i in range(180):
            dyawdpitch = i*d
            err = get_range_err_yaw_pitch_sq(dyawdpitch)
            if err < 1e6:
                # no longer divergent
                break
        else:
            # This direction never converged. Try the next one
            continue
        # Found a usable point. No need to look at the other directions
        break

    if err >= 1e6:
        print("## Initial search couldn't make the rays converge. Will not report optimized corrections")
        dyawdpitch = None

if dyawdpitch is not None:
    # Refine the yaw,pitch correction from the starting point. The optimizer
    # works in radians; everything is reported in degrees
    solution = scipy.optimize.minimize(get_range_err_yaw_pitch_sq, dyawdpitch)
    dyaw,dpitch = solution.x

    for report,angle in \
        ( ("## Optimized yaw   (rotation in epipolar plane) correction = {:.3f} degrees", dyaw),
          ("## Optimized pitch (tilt of epipolar plane)     correction = {:.3f} degrees", dpitch) ):
        print(report.format(angle/np.pi*180))

    # The cam1 <- cam0 transform with the optimized corrections applied.
    # Element 1 is what gets reported as the relative yaw
    rt10 = get_rt10_perturbed(r01err_yaw_axis   = n_epipolar_plane,
                              r01err_yaw        = dyaw,
                              r01err_pitch_axis = v_pitch,
                              r01err_pitch      = dpitch)
    relative_yaw_report = "## Optimized relative yaw (1 <- 0): {:.3f} degrees"
    print(relative_yaw_report.format(rt10[1] * 180./np.pi))

# Optionally write out a copy of the second camera model with the optimized
# yaw,pitch corrections baked into its extrinsics
if args.make_corrected_model1:
    # dyaw, dpitch and rt10 are only assigned once a usable starting point was
    # found, i.e. when dyawdpitch is not None. When the search fails they are
    # never bound, so checking dyaw itself here would die with a NameError
    # instead of printing the message below. dyawdpitch is always bound, so
    # that is what gets checked
    if dyawdpitch is None:
        print("I can't make the corrected model if I couldn't compute the yaw,pitch shifts", file=sys.stderr)
        sys.exit(1)

    # Chain the corrected (1 <- 0) transform onto the (0 <- ref) extrinsics of
    # the first model, and store the result as the extrinsics of the second
    # model
    rt_1r = mrcal.compose_rt(rt10,
                             models[0].extrinsics_rt_fromref())
    models[1].extrinsics_rt_fromref(rt_1r)

    # The corrected model goes to stdout
    models[1].write(sys.stdout)


# Produce the requested visualization, if any. An empty --viz means that no
# plot was asked for
if args.viz == 'match':
    visualize_match(args,
                    match_feature_diagnostics,
                    match_feature_out)
elif args.viz == 'uncertainty':
    visualize_uncertainty(args,
                          Var_p_calibration,
                          Var_p_observation,
                          Var_p_joint,
                          Rt01, p0)
elif args.viz != '':
    # Any other value is treated as an internal error
    raise Exception("Unknown --viz arg. This is a bug")
