Pairwise

Squared Euclidean distance between two vectors x and y of length N:

d(x,y) = \sum_{i=1}^N (x_i - y_i)^2
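As a quick worked check of the formula, take N = 3 with x = (1, 2, 3) and y = (4, 0, 3):

d(x,y) = (1-4)^2 + (2-0)^2 + (3-3)^2 = 9 + 4 + 0 = 13

so the (non-squared) Euclidean distance is \sqrt{13} \approx 3.606.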

Vector Representation

import numpy as np

def sqeuclidean_distances(x: np.ndarray, y: np.ndarray) -> float:
    # Elementwise differences, squared and summed: the formula above
    return float(np.sum((x - y) ** 2))
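A minimal sketch that checks the vectorized function against a literal loop over the sum. The loop version, sqeuclidean_loop, is a hypothetical helper for illustration only; the vectorized function repeats the one above so the block runs on its own.

```python
import numpy as np


def sqeuclidean_distances(x: np.ndarray, y: np.ndarray) -> float:
    """Vectorized squared Euclidean distance between two 1-D arrays."""
    return float(np.sum((x - y) ** 2))


def sqeuclidean_loop(x, y) -> float:
    """Hypothetical reference: literal transcription of sum_i (x_i - y_i)^2."""
    total = 0.0
    for xi, yi in zip(x, y):
        total += (xi - yi) ** 2
    return float(total)


x = np.array([0.0, 0.0])
y = np.array([3.0, 4.0])
print(sqeuclidean_distances(x, y))           # 25.0
print(np.sqrt(sqeuclidean_distances(x, y)))  # 5.0, the non-squared distance
```

The vectorized form pushes the loop into NumPy, which is what makes it practical for long vectors.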

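NumPy Implementation

import numpy as np

def euclidean_distance(x: np.ndarray, y: np.ndarray) -> float:
    # sqrt(x.x - 2 x.y + y.y), clipped at 0 because round-off can make the sum slightly negative
    return float(np.sqrt(np.maximum(np.dot(x, x) - 2.0 * np.dot(x, y) + np.dot(y, y), 0.0)))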
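The three dot products come from expanding the square in the definition of d:

d(x,y) = \sum_{i=1}^N (x_i - y_i)^2 = \sum_{i=1}^N x_i^2 - 2 \sum_{i=1}^N x_i y_i + \sum_{i=1}^N y_i^2 = x \cdot x - 2\, x \cdot y + y \cdot y

For row vectors stacked into matrices X (n by N) and Y (m by N), the same identity holds entry by entry, with the middle term becoming the matrix of inner products X Y^\top:

D_{ij} = \sum_{k=1}^N X_{ik}^2 + \sum_{k=1}^N Y_{jk}^2 - 2 \sum_{k=1}^N X_{ik} Y_{jk}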

Einsum

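import numpy as np

# x: (n, d), y: (m, d); square=True returns squared distances
XX = np.einsum("ik,ik->i", x, x)   # (n,) squared norms of the rows of x
YY = np.einsum("ik,ik->i", y, y)   # (m,) squared norms of the rows of y
XY = np.einsum("ik,jk->ij", x, y)  # (n, m) inner products <x_i, y_j>

# ||x_i - y_j||^2 = ||x_i||^2 + ||y_j||^2 - 2 <x_i, y_j>, clipped at 0 against round-off
sq_dists = np.maximum(XX[:, np.newaxis] + YY[np.newaxis, :] - 2 * XY, 0.0)
dists = sq_dists if square else np.sqrt(sq_dists)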
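The einsum snippet can be wrapped in a function and checked against a brute-force version that forms every difference vector by broadcasting. A minimal sketch; both function names are hypothetical.

```python
import numpy as np


def pairwise_distances_einsum(x: np.ndarray, y: np.ndarray, square: bool = False) -> np.ndarray:
    """Distances between rows of x (n, d) and rows of y (m, d), shape (n, m)."""
    XX = np.einsum("ik,ik->i", x, x)   # (n,) squared row norms of x
    YY = np.einsum("ik,ik->i", y, y)   # (m,) squared row norms of y
    XY = np.einsum("ik,jk->ij", x, y)  # (n, m) inner products
    sq = np.maximum(XX[:, np.newaxis] + YY[np.newaxis, :] - 2 * XY, 0.0)
    return sq if square else np.sqrt(sq)


def pairwise_distances_naive(x: np.ndarray, y: np.ndarray, square: bool = False) -> np.ndarray:
    """Reference: explicit differences, builds an (n, m, d) intermediate array."""
    sq = np.sum((x[:, np.newaxis, :] - y[np.newaxis, :, :]) ** 2, axis=-1)
    return sq if square else np.sqrt(sq)


rng = np.random.default_rng(0)
x = rng.normal(size=(5, 3))
y = rng.normal(size=(4, 3))
D = pairwise_distances_einsum(x, y)
print(D.shape)  # (5, 4)
```

The naive version needs memory proportional to n * m * d, while the einsum version only needs the (n, m) result plus the two norm vectors, which is the usual reason to prefer the expansion for large inputs.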

Dot Products

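import numpy as np

# x, y: 1-D arrays of shape (d,); square=True returns the squared distance
XX = np.dot(x, x)
YY = np.dot(y, y)
XY = np.dot(x, y)

sq_dist = max(XX + YY - 2 * XY, 0.0)  # clip at 0 against round-off
dists = sq_dist if square else np.sqrt(sq_dist)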
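The dot-product form replaces the explicit difference with a subtraction of large, nearly equal numbers, so it can lose precision when x and y are close relative to their norms. A small float64 sketch; the helper names are hypothetical.

```python
import numpy as np


def sq_dist_direct(x: np.ndarray, y: np.ndarray) -> float:
    """Squared distance from the explicit difference vector."""
    return float(np.sum((x - y) ** 2))


def sq_dist_expanded(x: np.ndarray, y: np.ndarray) -> float:
    """Squared distance from the expansion x.x + y.y - 2 x.y."""
    return float(np.dot(x, x) + np.dot(y, y) - 2 * np.dot(x, y))


# Two points exactly 1 apart, but far from the origin
x = np.array([1e8])
y = np.array([1e8 + 1.0])
print(sq_dist_direct(x, y))    # 1.0
print(sq_dist_expanded(x, y))  # 0.0: the 1 is lost when the terms near 1e16 are rounded
```

This is one reason the direct form can be preferable when accuracy at small distances matters, while the expansion remains convenient for the pairwise case because it reduces to matrix products.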

Pairwise Distances

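With JAX, a distance dist(x, y, **arg) between two single vectors can be lifted to a full distance matrix by nesting two vmaps. The inner vmap fixes x and maps over the rows of y; the outer one maps over the rows of x. Calling the jitted result on x of shape (n, d) and y of shape (m, d), as in dists = pairwise_dists(x, y), returns an (n, m) matrix. Here arg is a dict of extra keyword arguments forwarded to dist.

from functools import partial
from jax import jit, vmap

pairwise_dists = jit(vmap(vmap(partial(dist, **arg), in_axes=(None, 0)), in_axes=(0, None)))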
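A self-contained sketch of the nested-vmap pattern, assuming JAX is installed. The single-pair function euclidean and the helper make_pairwise are hypothetical stand-ins for dist and the one-liner above.

```python
from functools import partial

import jax.numpy as jnp
from jax import jit, vmap


def euclidean(x, y, square=False):
    """Distance between two 1-D vectors; square=True skips the sqrt."""
    d2 = jnp.sum((x - y) ** 2)
    # square is a plain Python bool bound before tracing, so this branch is static
    return d2 if square else jnp.sqrt(d2)


def make_pairwise(dist, **kwargs):
    """Hypothetical helper: lift a single-pair distance to an (n, m) matrix."""
    single = partial(dist, **kwargs)
    row = vmap(single, in_axes=(None, 0))       # fix x, map over rows of y -> (m,)
    return jit(vmap(row, in_axes=(0, None)))    # map over rows of x, fix y -> (n, m)


pairwise = make_pairwise(euclidean, square=True)
x = jnp.array([[0.0, 0.0], [1.0, 1.0]])
y = jnp.array([[3.0, 4.0], [0.0, 0.0], [1.0, 0.0]])
D = pairwise(x, y)
print(D)  # [[25.  0.  1.] [13.  2.  1.]]
```

Unlike the einsum expansion, this computes each difference explicitly, so it sidesteps the cancellation issue at the cost of more arithmetic; the payoff is that any single-pair distance can be mapped without writing a batched version by hand.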