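Fixed-Point Methods

A fixed point of $f(\cdot, \theta)$ is a point $x^*$ satisfying $x^* = f(x^*, \theta)$; the solvers below approximate it by repeatedly applying the update $x_{k+1} = f(x_k, \theta)$.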

  • Efficient and Modular Implicit Differentiation - Blondel et al. (2021)


Pseudo-Code

def f(x, params):
    """Apply one step of the fixed-point map ``x -> f(x, params)``.

    Placeholder: replace the body with the actual update rule. As written it
    ignores both arguments and returns ``Ellipsis``.
    """
    update = ...
    return update

Loops

def fp_solver(x, params, iterations=100):
    """Approximate a fixed point of ``f`` by plain repeated application.

    Computes ``x <- f(x, params)`` exactly ``iterations`` times, starting
    from the given ``x``. There is no convergence test: the loop always
    runs the full number of steps.

    Args:
        x: Initial guess.
        params: Parameters forwarded unchanged to ``f`` on every step.
        iterations: Number of times ``f`` is applied (default 100). A value
            of 0 or less returns ``x`` unchanged.

    Returns:
        The iterate after ``iterations`` applications of ``f``.
    """
    for _ in range(iterations):
        x = f(x, params)

    return x

JAX Scan

def fp_solver(x, params, iterations=100):
    """Approximate a fixed point of ``f`` using ``jax.lax.scan``.

    Performs the same computation as the plain-loop version: it repeats
    ``x <- f(x, params)`` ``iterations`` times. Here the loop is expressed
    as a scan, so the body is traced once rather than unrolled.

    Args:
        x: Initial guess. It is used as the scan carry, so ``f`` must return
            a value with the same pytree structure, shapes and dtypes.
        params: Parameters closed over by the scan body and passed to ``f``.
        iterations: Number of steps (default 100). Must be a concrete Python
            int because it sets the scan ``length``.

    Returns:
        The carry after ``iterations`` steps.
    """

    def body(carry, _):
        # No per-step output is collected, so emit None as the scan's ``y``.
        return f(carry, params), None

    # xs=None means there is nothing to scan over, so ``length`` is required;
    # it must come from the ``iterations`` argument.
    x_phi, _ = jax.lax.scan(body, init=x, xs=None, length=iterations)

    return x_phi

Fixed-Point Solver


def fp_projection_update_opt(params, model, x, y, mask, fp_fn, **kwargs):
    """Find a fixed point of the projection update with an iterative solver.

    ``params`` are first merged into ``model``. The merged model is passed to
    the solver's ``run`` as an extra positional argument. Presumably the
    solver (jaxopt-style, TODO confirm) forwards that argument to the
    fixed-point map.

    Args:
        params: Values merged into ``model`` via ``model.merge``.
        model: Object providing ``merge(params)``.
        x: Initial iterate handed to ``run``.
        y: Captured by the fixed-point map and passed to
            ``fn_projection_update`` on every evaluation.
        mask: Captured alongside ``y`` and passed to ``fn_projection_update``.
        fp_fn: Solver factory called as ``fp_fn(fixed_point_fun=..., **kwargs)``.
            The returned object must expose ``run(x, model)``, whose result
            has a ``params`` attribute.
        **kwargs: Extra solver options forwarded to ``fp_fn``.

    Returns:
        The ``params`` attribute of the object returned by ``run``.
    """
    merged = model.merge(params)

    def projection_map(x, model):
        # ``model`` here is whatever the solver supplies when it calls the
        # map (presumably the merged model given to ``run``), not the outer
        # argument of the same name.
        return fn_projection_update(model, x, y, mask)

    solver = fp_fn(fixed_point_fun=projection_map, **kwargs)
    result = solver.run(x, merged)
    return result.params