This notebook shows how the Geodesic Distance Transform (GDT) can be implemented in Python using the Fast Sweeping Method (FSM). GDT can be thought of as computing shortest paths on an image plane under a given cost function. It can be used as part of more complex image processing tasks such as segmentation or coloring.
The description of different GDT construction methods can be found in .
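    # some preparations first
    # (assumes pylab mode, e.g. %pylab inline, which provides np, imshow, figsize, ...)
    %load_ext cythonmagic   # newer IPython/Cython releases use `%load_ext Cython` instead
    from scipy.misc import lena   # removed from later SciPy releases; any grayscale image works
    import scipy.ndimage as nd
    import itertools as it
    figsize(10, 8)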
We use the image gradient magnitude to construct our cost function.
    img = np.float32(lena()) / 255.0
    G = nd.gaussian_gradient_magnitude(img, 1.0)
    Cost = 1 + G*200
    imshow(Cost)
    colorbar()
    _=title('Cost function')
The sweep function is the core of the algorithm. It propagates the shortest-path wavefront in the right-down direction. We use Cython to accelerate the pixel iteration.
    %%cython
    from libc.math cimport sqrt
    def sweep(float[:,:] A, float[:,:] Cost):
        cdef int i, j
        cdef float t0, t1, t2, C, max_diff = 0.0
        for i in xrange(1, A.shape[0]):
            for j in xrange(1, A.shape[1]):
                # distances of the left and upper neighbours
                t1, t2 = A[i, j-1], A[i-1, j]
                C = Cost[i, j]
                if abs(t1-t2) > C:
                    t0 = min(t1, t2) + C  # handle degenerate case
                else:
                    t0 = 0.5*(t1 + t2 + sqrt(2*C**2 - (t1-t2)**2))
                # track the largest decrease made during this sweep
                max_diff = max(max_diff, A[i, j] - t0)
                A[i, j] = min(A[i, j], t0)
        return max_diff
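One way to read the update inside `sweep` is as a first-order upwind discretization of the Eikonal equation on a grid with unit spacing. The notebook does not state this derivation, so treat it as an interpretation. Here t_1 and t_2 are the already-visited left and upper neighbours `A[i, j-1]` and `A[i-1, j]`, C is `Cost[i, j]`, and t_0 is the candidate distance for the current cell. The larger root of the quadratic is taken, and the one-sided update is used when the neighbours differ by more than C:

```latex
\begin{aligned}
(t_0 - t_1)^2 + (t_0 - t_2)^2 &= C^2 \\
t_0 &= \tfrac{1}{2}\left(t_1 + t_2 + \sqrt{2C^2 - (t_1 - t_2)^2}\right) \\
t_0 &= \min(t_1, t_2) + C \quad \text{if } |t_1 - t_2| > C
\end{aligned}
```

At the boundary |t_1 - t_2| = C both expressions give max(t_1, t_2), so the two branches join continuously.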
The GDT function uses NumPy array slicing to sweep the wavefront in the four possible directions.
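    def GDT(A, C):
        A = A.copy()
        # reversed slices are views of A, so each sweep updates A in place
        sweeps = [A, A[:,::-1], A[::-1], A[::-1,::-1]]
        costs = [C, C[:,::-1], C[::-1], C[::-1,::-1]]
        for i, (a, c) in enumerate(it.cycle(zip(sweeps, costs))):
            print i,  # Python 2 print; the trailing comma keeps output on one line
            # stop once a sweep changes little, or after a fixed number of sweeps
            if sweep(a, c) < 1.0 or i >= 40:
                break
        return A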
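For readers without a Cython toolchain, the four-direction sweeping can be sketched in pure Python 3. This is an illustrative sketch, not the notebook's implementation: `sweep_py` and `gdt_py` are hypothetical names, plain nested lists stand in for NumPy arrays, explicit index steps `(di, dj)` replace the reversed array views, and the stopping rule (a full cycle of four sweeps that changes nothing beyond `tol`) is an assumption that differs from the `< 1.0` threshold used above.

```python
import math


def sweep_py(A, cost, di, dj):
    """Run one sweep over the grid A (list of lists) in direction (di, dj).

    di and dj are each +1 or -1. A is updated in place and the largest
    decrease applied to any cell is returned.
    """
    rows, cols = len(A), len(A[0])
    # Visit cells so that the two upwind neighbours are processed first.
    i_range = range(1, rows) if di > 0 else range(rows - 2, -1, -1)
    j_range = range(1, cols) if dj > 0 else range(cols - 2, -1, -1)
    max_diff = 0.0
    for i in i_range:
        for j in j_range:
            t1, t2 = A[i][j - dj], A[i - di][j]
            c = cost[i][j]
            if abs(t1 - t2) > c:
                t0 = min(t1, t2) + c  # one-sided (degenerate) update
            else:
                t0 = 0.5 * (t1 + t2 + math.sqrt(2 * c * c - (t1 - t2) ** 2))
            max_diff = max(max_diff, A[i][j] - t0)
            A[i][j] = min(A[i][j], t0)
    return max_diff


def gdt_py(A, cost, tol=1e-6, max_cycles=10):
    """Repeat cycles of four directional sweeps until no cell changes by tol."""
    A = [row[:] for row in A]  # work on a copy, like A.copy() in GDT
    directions = [(1, 1), (1, -1), (-1, 1), (-1, -1)]
    for _ in range(max_cycles):
        change = max(sweep_py(A, cost, di, dj) for di, dj in directions)
        if change < tol:
            break
    return A


# Usage: uniform cost on a 5x5 grid with the source in the centre.
n = 5
cost = [[1.0] * n for _ in range(n)]
A0 = [[1e5] * n for _ in range(n)]
A0[2][2] = 0.0
D = gdt_py(A0, cost)
```

With a uniform cost the result is symmetric around the source: the four direct neighbours end up at distance 1 and the diagonal neighbours at 1 + sqrt(2)/2, which is the two-sided update applied to two neighbours at distance 1. Checking a whole cycle before stopping avoids exiting early when the first sweep direction happens to point away from the source.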
    A = zeros_like(Cost)  # create distance accumulation array
    A[:] = 1e5            # fill it with large values to mark
                          # cells with unknown distance
    A[300, 300] = 0       # set the source
    A = GDT(A, Cost)
0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
Here is the contour plot of the GDT with the given cost function. Note how the isolines tend to follow image edges.
    contour(A, 50, origin='image')
    _=axis('image')
This post is generated from an IPython Notebook, which can be found in my GitHub repository.