Function for single step of SLIC #8

Open
vishwa91 opened this issue Jan 8, 2020 · 8 comments

Comments

@vishwa91

vishwa91 commented Jan 8, 2020

Would it be possible to get a function handle that performs a single step of the SLIC update?

Specifically: given the RGB (or LAB) image and the centroids as input, the function should output the label map. This would be very useful when the centroids are constrained to follow a certain structure.

Thank you!
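For reference, the operation requested above is the SLIC assignment step: every pixel takes the label of the centroid with the smallest combined color-plus-position distance. Below is a minimal NumPy sketch under stated assumptions: a float LAB image of shape (H, W, 3), centroids given as rows of (L, a, b, y, x), and the distance from the original SLIC formulation, D^2 = d_lab^2 + (m / S)^2 * d_xy^2, with grid interval S. The name `assign_labels` is hypothetical, and unlike real SLIC it compares every pixel with every centroid instead of searching a local window, so memory grows with H * W * K.

```python
import numpy as np

def assign_labels(lab, centroids, compactness=10.0):
    """One SLIC-style assignment step (hypothetical helper, brute force).

    lab:        float array (H, W, 3), LAB image
    centroids:  float array (K, 5), rows are (L, a, b, y, x)  [assumed layout]
    returns:    int array (H, W), index of the nearest centroid for each pixel
    """
    H, W, _ = lab.shape
    K = centroids.shape[0]
    S = np.sqrt(H * W / K)                  # grid interval, as in SLIC
    yy, xx = np.mgrid[:H, :W]

    # Squared color distance to every centroid, shape (H, W, K)
    d_lab = ((lab[:, :, None, :] - centroids[None, None, :, :3]) ** 2).sum(axis=-1)
    # Squared spatial distance to every centroid, shape (H, W, K)
    d_xy = (yy[:, :, None] - centroids[:, 3]) ** 2 + (xx[:, :, None] - centroids[:, 4]) ** 2

    # D^2 = d_lab^2 + (m / S)^2 * d_xy^2; the nearest centroid gives the label
    dist = d_lab + (compactness / S) ** 2 * d_xy
    return dist.argmin(axis=-1)
```

With a uniform image, the labels depend only on position; with two centroids at the same position, they depend only on color.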

@Algy
Owner

Algy commented Jan 9, 2020

You can modify the centroid coordinates by assigning a new list to slic.slic_model.clusters, where slic is an instance of the Slic or SlicAvx2 class. The value is a list of dicts in which each centroid's position is stored under the key 'yx'. Note that the value of the key 'rgb' actually holds the LAB color tuple multiplied by 2. Also, you should assign a new list to the property rather than modifying the existing value in place.
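A minimal sketch of building that new list, assuming (as described above) that `slic.slic_model.clusters` yields a list of dicts with the position stored under `'yx'` as a (y, x) pair. The helper `with_centroids` is hypothetical: it copies each dict so the list read from the model is never modified in place, and returns a fresh list that can be assigned back to the property. Updating `'rgb'` is left out, because the exact LAB scale behind the "multiplied by 2" note is not spelled out here.

```python
def with_centroids(clusters, centers_yx):
    """Return a NEW list of cluster dicts with 'yx' replaced (hypothetical helper).

    clusters:   list of dicts as read from slic.slic_model.clusters
    centers_yx: sequence of (y, x) pairs, one per cluster, in the same order
    """
    if len(centers_yx) != len(clusters):
        raise ValueError(f"expected {len(clusters)} centers, got {len(centers_yx)}")
    new_clusters = []
    for cluster, (y, x) in zip(clusters, centers_yx):
        updated = dict(cluster)      # shallow copy: the input dict stays untouched
        updated['yx'] = (y, x)       # other keys such as 'rgb' are carried over as-is
        new_clusters.append(updated)
    return new_clusters


# Usage with fast-slic (not run here):
#   slic.slic_model.clusters = with_centroids(slic.slic_model.clusters, my_centers)
```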

@Algy
Owner

Algy commented Jan 9, 2020

You can run just one iteration with slic.iterate(img, 1).
However, there is one caveat to consider: fast-slic subsamples the image row by row. That is, when subsample_stride is 3, only one third of the image rows are used in each iteration, so a single iteration might not use all of the information in the image. If that matters to you, consider setting subsample_stride to 1 in the constructor of the Slic class, or increase the number of iterations to 3.
For more information on subsampling, please refer to my undergraduate paper uploaded in issue #7.
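To illustrate the caveat, the sketch below models row subsampling under an explicit assumption: iteration i visits the rows whose index is congruent to i modulo subsample_stride. This is not necessarily fast-slic's exact scheme (the paper in #7 is the authority on that). It only shows why one iteration with stride 3 touches a third of the rows, while three consecutive iterations together cover every row.

```python
def rows_used(height, stride, iteration):
    """Rows visited in one iteration under an ASSUMED interleaved scheme.

    Assumption (not fast-slic's documented behaviour): iteration i uses
    the rows r with r % stride == i % stride.
    """
    offset = iteration % stride
    return list(range(offset, height, stride))


def rows_covered(height, stride, n_iterations):
    """Union of the rows visited over n_iterations consecutive iterations."""
    covered = set()
    for i in range(n_iterations):
        covered.update(rows_used(height, stride, i))
    return covered


height = 12
print(len(rows_used(height, 3, 0)))                       # 4 of 12 rows
print(rows_covered(height, 3, 3) == set(range(height)))   # True
```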

@vishwa91
Author

If I understand correctly, these are the steps I need to follow:

  1. Assign the new centroids to slic.slic_model.clusters (as a new list, not modified in place).
  2. Iterate three times (if subsample_stride == 3) with the same slic object as above.

@weiwenchuan
Copy link

Hi @vishwa91, I have the same problem and would like to ask whether you have solved it. I tried the solution suggested by @Algy; however, the result is not what I expected. Please see the example below. I defined a few centroids (not evenly distributed), set subsample_stride=1, and ran only one iteration. The centroids I set are labelled in picture 1, and the SLIC labels are shown in picture 2. As you can see, the segmentation does not follow the defined centroids.
[Screenshot 2021-09-18: picture 1, the defined centroids; picture 2, the resulting SLIC labels]
I also tried setting the number of iterations to 0; the result is below.
[Screenshot 2021-09-18: result with the number of iterations set to 0]
I'm not sure whether the function still runs several iterations when the count is set to 1 or 0, or whether I did something wrong. @vishwa91, could you please let me know if you know what the problem is? Any insight is appreciated. My code is as follows:

from fast_slic import Slic
from fast_slic.avx2 import SlicAvx2
from skimage import segmentation, color
import matplotlib.pyplot as plt
import cv2
import imutils

image = cv2.imread("fish.jpeg")
image = imutils.resize(image, width=300)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

slic = Slic(num_components=12, compactness=10, subsample_stride=1)
n_clusters = len(slic.slic_model.clusters)

# use pre-defined centroids
centers_w = [20, 70, 120, 170, 220, 270] * 2
centers_h = [40] * 6 + [80] * 6
new_clusters = slic.slic_model.clusters.copy()
for k, cluster in enumerate(new_clusters):
    cluster['yx'] = (centers_h[k], centers_w[k])
    cluster['number'] = k
slic.slic_model.clusters = new_clusters
print('after center assignment', slic.slic_model.clusters)

# iterate only once
slic_result = slic.iterate(image, 1)

# plot results
fig, ax_arr = plt.subplots(1, 2)
ax1, ax2 = ax_arr.ravel()

# show centers in image
for cluster in new_clusters:
    label = cluster['number']
    center = cluster['yx']
    cv2.circle(image, (int(center[1]), int(center[0])), 3, [255, 0, 255], -1)
    cv2.putText(image, str(label), (int(center[1]), int(center[0])), cv2.FONT_HERSHEY_PLAIN, 2, [255, 0, 255], 2)
ax1.imshow(segmentation.mark_boundaries(image, slic_result))

ax2.imshow(slic_result)
plt.show()
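One way to check whether the labels actually follow the assigned centers is to compute the mean (y, x) position of every label in the returned map and compare it with the centers that were set. A small NumPy sketch; `label_centroids` is a hypothetical helper that assumes non-negative integer labels such as `slic_result`, and the commented usage assumes that label values correspond to the clusters' `'number'` field.

```python
import numpy as np

def label_centroids(labels):
    """Mean (y, x) of every label present in a non-negative integer label map.

    Returns a dict {label: (mean_y, mean_x)}; labels that do not occur are omitted.
    """
    labels = np.asarray(labels)
    ys, xs = np.indices(labels.shape)
    flat = labels.ravel()
    counts = np.bincount(flat)                       # pixels per label
    sum_y = np.bincount(flat, weights=ys.ravel())    # sum of row indices per label
    sum_x = np.bincount(flat, weights=xs.ravel())    # sum of column indices per label
    present = np.nonzero(counts)[0]
    return {int(k): (sum_y[k] / counts[k], sum_x[k] / counts[k]) for k in present}


# Usage with the script above (not run here):
#   found = label_centroids(slic_result)
#   for cluster in new_clusters:
#       print(cluster['number'], cluster['yx'], found.get(cluster['number']))
```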

@vishwa91
Author

@weiwenchuan I could not get @Algy's approach to work with my own centroids, so I wrote the partial SLIC update code below:

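import numpy as np
import cv2

import cassi_cp  # compiled Cython extension, shown below

def slic_update(imrgb, mask, compactness=10.0):
    '''
        Performs a single SLIC assignment step to update pixel membership

        Inputs:
            imrgb: RGB image, H x W x 3
            mask: H x W mask whose pixels equal to 1 mark the superpixel
                centroids (sparse sampling mask, after sanitizing)
            compactness: SLIC compactness parameter

        Outputs:
            L: superpixel membership map (uint16)
            N: total number of superpixels
    '''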

    H, W, _ = imrgb.shape
    ch, cw = np.where(mask == 1)

    # Create LabXY image
    [Y, X] = np.mgrid[:H, :W]

    imlabxy = np.zeros((H, W, 5), dtype=np.float32)

    imlabxy[:, :, :3] = cv2.cvtColor(imrgb, cv2.COLOR_RGB2Lab)
    imlabxy[:, :, 3] = X
    imlabxy[:, :, 4] = Y

    # LabXY feature vector of each centroid
    centroids_labxy = imlabxy[ch, cw, :].astype(np.float32)
    N = ch.size
    # Grid interval: side of a square holding H*W/N pixels, as in SLIC
    S = int(np.sqrt(H*W/N))

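    dist_matrix = np.full((H, W), np.inf, dtype=np.float32)
    L = np.ones((H, W), dtype=np.uint16)

    # Inefficient, but simple: each centroid only competes for pixels within
    # +/-2S of its position, and every pixel keeps the closest centroid seen so
    # far. Pixels farther than 2S from every centroid keep the initial label 1.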
    for idx in range(N):
        hmin = max(0, ch[idx] - 2*S); hmax = min(H, ch[idx] + 2*S)
        wmin = max(0, cw[idx] - 2*S); wmax = min(W, cw[idx] + 2*S)

        imlabxy_patch = imlabxy[hmin:hmax, wmin:wmax, :]
        dist_patch_old = dist_matrix[hmin:hmax, wmin:wmax]
        dist_patch = cassi_cp._get_dist_cp(centroids_labxy[idx, :], imlabxy_patch,
                                           np.float32(compactness), S)

        # L_patch is a view into L, so this assignment updates L in place
        L_patch = L[hmin:hmax, wmin:wmax]
        L_patch[dist_patch < dist_patch_old] = idx

        dist_matrix[hmin:hmax, wmin:wmax] = np.minimum(dist_patch,
                                                       dist_patch_old)

    return L.astype(np.uint16), N
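slic_update takes the centroids as a binary mask rather than a list of coordinates. A minimal sketch of building that mask from (y, x) centers, using a hypothetical helper `centers_to_mask`. One detail worth knowing: `np.where` returns the mask pixels in row-major (raster) order, so label `idx` in the output refers to the idx-th centroid in that order, not necessarily the order of your original list. The helper therefore also returns that order so labels can be mapped back. Duplicate centers collapse into a single mask pixel.

```python
import numpy as np

def centers_to_mask(centers_yx, H, W):
    """Build the binary centroid mask expected by slic_update (hypothetical helper).

    centers_yx: iterable of integer (y, x) pixel positions inside an H x W image.
    Returns (mask, order), where order[idx] is the (y, x) of the centroid that
    receives label idx, i.e. the row-major order produced by np.where(mask == 1).
    """
    mask = np.zeros((H, W), dtype=np.uint8)
    for y, x in centers_yx:
        mask[int(y), int(x)] = 1
    ch, cw = np.where(mask == 1)
    order = list(zip(ch.tolist(), cw.tolist()))
    return mask, order


# Usage (not run here), e.g. with the 12 grid centers from the earlier script:
#   centers = list(zip([40] * 6 + [80] * 6, [20, 70, 120, 170, 220, 270] * 2))
#   mask, order = centers_to_mask(centers, H, W)
#   L, N = slic_update(imrgb, mask)
```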

The function cassi_cp._get_dist_cp() is written in Cython (cassi_cp.pyx):

import numpy as np
import cv2
from cython.parallel import parallel, prange

# Compile time optimizations
cimport numpy as np
cimport cython

# NumPy dtypes and matching C types (only float32 is used below)
DTYPE_UINT8 = np.uint8
DTYPE_UINT16 = np.uint16
DTYPE_FLOAT32 = np.float32
DTYPE_INT16 = np.int16

ctypedef np.uint8_t DTYPE_UINT8_t
ctypedef np.uint16_t DTYPE_UINT16_t
ctypedef np.float32_t DTYPE_FLOAT32_t
ctypedef np.int16_t DTYPE_INT16_t

@cython.boundscheck(False)
@cython.wraparound(False)
def _get_dist_cp(np.ndarray[DTYPE_FLOAT32_t, ndim=1] centroid_xy,
                 np.ndarray[DTYPE_FLOAT32_t, ndim=3] imlabxy_patch,
                 float compactness, int S):
    '''
        Function to rapidly compute distance from a centroid over a patch
    '''
    # Declare all variables ahead
    cdef int H
    cdef int W
    cdef int h
    cdef int w
    cdef float C

    H = imlabxy_patch.shape[0]
    W = imlabxy_patch.shape[1]
    C = compactness/S

    # Create new matrix to store distances
    dist = np.zeros((H, W), dtype=DTYPE_FLOAT32)

    # Creating a data view will make all operations much faster
    cdef DTYPE_FLOAT32_t[:, :] dist_view = dist

    # Loop over all pixels in the patch; prange distributes rows across OpenMP threads
    for h in prange(H, nogil=True):
        for w in range(W):
            # Unroll the whole computation
            dist_view[h, w] = ((imlabxy_patch[h, w, 0] - centroid_xy[0])**2 + \
                               (imlabxy_patch[h, w, 1] - centroid_xy[1])**2 + \
                               (imlabxy_patch[h, w, 2] - centroid_xy[2])**2 + \
                             C*(imlabxy_patch[h, w, 3] - centroid_xy[3])**2 + \
                             C*(imlabxy_patch[h, w, 4] - centroid_xy[4])**2)

    return dist
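If compiling the extension is inconvenient, the same per-patch distance can be computed with vectorized NumPy, at some cost in speed. This sketch is meant as a stand-in for `cassi_cp._get_dist_cp` with the same arguments; `get_dist_np` is a hypothetical name. It reproduces the Cython formula exactly, including the weight C = compactness / S on the squared spatial terms (the original SLIC formulation weights the squared spatial distance by (compactness / S)^2 instead).

```python
import numpy as np

def get_dist_np(centroid_labxy, imlabxy_patch, compactness, S):
    """Vectorized NumPy version of _get_dist_cp (same formula, no compilation).

    centroid_labxy: float32 array (5,), the (L, a, b, x, y) features of one centroid
    imlabxy_patch:  float32 array (h, w, 5), LabXY features of a patch
    Returns a float32 (h, w) array of distances.
    """
    C = np.float32(compactness / S)
    diff = imlabxy_patch - centroid_labxy            # broadcasts over (h, w)
    color = (diff[:, :, :3] ** 2).sum(axis=-1)       # squared Lab distance
    spatial = (diff[:, :, 3:] ** 2).sum(axis=-1)     # squared XY distance
    return (color + C * spatial).astype(np.float32)
```

In slic_update, the call to `cassi_cp._get_dist_cp(...)` could then be swapped for `get_dist_np(...)` with the same arguments.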

The relevant setup.py file for compiling:

# Compilation tools
from setuptools import Extension, setup
from Cython.Build import cythonize

# Scientific computing
import numpy as np

ext_modules = [
    Extension(
        "cassi_cp",
        ["cassi_cp.pyx"],
        extra_compile_args=['-fopenmp', '-march=native', '-O3', '-ffast-math'],
        extra_link_args=['-fopenmp'],
        include_dirs=[np.get_include()]
    )
]

setup(
    name='cassi_cp',
    ext_modules=cythonize(ext_modules)
)

Hope that helps!

@weiwenchuan

Hi @vishwa91, thanks for the prompt reply! So you wrote your own code and didn't use fast-slic? I haven't used Cython before, but I'll definitely try your code. I also tried writing my own single-step SLIC code in pure Python, but it ran too slowly, which is why I turned to fast-slic. Did you measure the efficiency (running time) of your update function?

@vishwa91
Author

@weiwenchuan: yes, I used my own code. To compile the Cython code, you may need to run:

python setup.py build_ext --inplace

Regarding efficiency: no, I did not profile my code, but I hope it is fairly fast, since the costliest step runs in Cython (practically C).

Hope that helps!
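For a quick answer to the running-time question, a minimal timing sketch using only the standard library. `best_of` is a hypothetical helper; taking the minimum over several repeats reduces noise from other processes. The commented usage lines assume the objects from the snippets above (the `Slic` instance, `slic_update`, an RGB image, and a centroid mask) and are not run here.

```python
import time

def best_of(fn, *args, repeat=5, **kwargs):
    """Return the fastest wall-clock time (seconds) of fn(*args, **kwargs) over `repeat` runs."""
    best = float("inf")
    for _ in range(repeat):
        start = time.perf_counter()
        fn(*args, **kwargs)
        best = min(best, time.perf_counter() - start)
    return best


# Usage (assumes objects defined earlier in this thread):
#   print("slic_update:", best_of(slic_update, imrgb, mask))
#   print("fast-slic  :", best_of(slic.iterate, image, 1))
```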

@weiwenchuan

Hi @vishwa91, thank you. I just tried it and it works (although the running time is longer than with fast-slic, the segmentation result is really clear). Thank you for sharing your code!
