Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement medfilt1d_ng #2285

Open
1 of 4 tasks
kif opened this issue Sep 22, 2024 · 6 comments
Open
1 of 4 tasks

Implement medfilt1d_ng #2285

kif opened this issue Sep 22, 2024 · 6 comments
Assignees

Comments

@kif
Copy link
Member

kif commented Sep 22, 2024

  • in Python
  • in Cython
  • in OpenCL
  • With tests
@kif kif added enhancement performance profiling issues labels Sep 22, 2024
@kif kif self-assigned this Sep 22, 2024
@kif
Copy link
Member Author

kif commented Sep 25, 2024

After the preliminary work in #2261, we validated that this is feasible. The expected performance is about 5 s for a 2 Mpix image using Python.

@kif kif added the duplicate label Sep 25, 2024
@kif
Copy link
Member Author

kif commented Oct 5, 2024

Quick&dirty implementation in python:

def azimuthal_median():
    """Accumulate per-bin sums over the pixels lying in the 0.2-0.8 quantile range.

    Reads two module-level objects: ``csr`` (only ``csr[2]``, the row
    pointer array delimiting each bin, is used) and ``work1`` (a structured
    array providing at least the fields ``"s"``, ``"v"`` and ``"n"``).
    NOTE(review): ``numpy.sort`` orders a structured array by its fields in
    declaration order, so the first field of ``work1`` is presumably the
    value the median is taken on -- confirm against the code building it.

    A pixel is kept when its cumulative-``n`` interval ``[lower, upper]``
    overlaps ``[q_start * total, q_stop * total]``.

    :return: 4-tuple of float64 arrays, one entry per bin:
             ``(signal, variance, norm, norm2)``, i.e. the sums of ``s``,
             ``v``, ``n`` and ``n**2`` over the kept pixels. Bins without
             any pixel are left at zero.
    """
    q_start = 0.2
    q_stop = 0.8
    indptr = csr[2]
    nbins = indptr.size - 1
    signal = numpy.zeros(nbins, dtype="float64")
    norm = numpy.zeros(nbins, dtype="float64")
    norm2 = numpy.zeros(nbins, dtype="float64")
    variance = numpy.zeros(nbins, dtype="float64")
    for i, start in enumerate(indptr[:-1]):
        stop = indptr[i + 1]
        if stop <= start:
            # Empty bin: `upper[-1]` below would raise IndexError; the sums
            # for this bin simply stay at zero.
            continue
        tmp = numpy.sort(work1[start:stop])
        # Cumulative normalization after / before each sorted pixel.
        upper = numpy.cumsum(tmp["n"])
        last = upper[-1]
        lower = numpy.concatenate(([0], upper[:-1]))
        mask = numpy.logical_and(upper >= q_start * last, lower <= q_stop * last)
        tmp = tmp[mask]
        signal[i] = tmp["s"].sum(dtype="float64")
        variance[i] = tmp["v"].sum(dtype="float64")
        norm[i] = tmp["n"].sum(dtype="float64")
        norm2[i] = (tmp["n"]**2).sum(dtype="float64")
    return signal, variance, norm, norm2

@kif
Copy link
Member Author

kif commented Oct 5, 2024

Non-regression test:

# Time one call, then benchmark repeated calls (IPython magics), using the
# full quantile range (0.0, 1.0).
%time res = cython_medfilt3(prep, csr[0], csr[1], csr[2], (0.0,1.0))
%timeit res = cython_medfilt3(prep, csr[0], csr[1], csr[2], (0.0,1.0))

# Reference result from the regular CSR/cython integrator with a Poisson error model.
ref = ai.integrate1d(img, 2500, method=("full", "csr", "cython"), unit="r_mm", error_model="poisson")
# Largest relative deviation between the reference sums and res[0..3]
# (signal, variance, normalization, squared normalization).
abs((ref.sum_signal -res[0])/ref.sum_signal).max(), \
abs((ref.sum_variance -res[1])/ref.sum_variance).max(),\
abs((ref.sum_normalization -res[2])/ref.sum_normalization).max(),\
abs((ref.sum_normalization2 -res[3])/ref.sum_normalization2).max()

@kif
Copy link
Member Author

kif commented Oct 5, 2024

Quick&dirty implementation in cython:

%%cython -a -c-fopenmp --link-args=-fopenmp

# distutils: language = c++
#cython: boundscheck=False, wraparound=False, cdivision=True, initializedcheck=False

import numpy
from libcpp cimport bool
from libcpp.algorithm cimport sort
from cython cimport floating
from cython.parallel import prange

# Per-pixel work record used by cython_medfilt3 (four packed float32 values).
cdef struct float4:
    # sort key: first set to s/n, later overwritten with the cumulative sum of `n` within a bin
    float a
    # signal contribution (s * coefficient)
    float s
    # variance contribution (v * coefficient**2)
    float v
    # normalization contribution (n * coefficient)
    float n

# Strict "less than" ordering on the `a` field, used as the comparator for std::sort.
cdef bool cmp(float4 a, float4 b) noexcept nogil:
    return a.a < b.a
    
# Sort a contiguous buffer of float4 records in place, by increasing `a` field.
cdef void sort_3(float4[::1] ary) noexcept nogil:
    cdef:
        float4* first = &ary[0]
    # std::sort works on the half-open pointer range [first, first + length)
    sort(first, first + ary.shape[0], cmp)


def cython_medfilt3(float[:, ::1] prep, 
                    float[::1] data, 
                    int[::1]indices, 
                    int[::1] indptr, 
                    quant):
    """Quantile-filtered accumulation of pre-processed pixels over CSR bins.

    For every bin, the contributing pixels are sorted by signal/normalization,
    the normalization is accumulated along the sorted pixels, and only pixels
    whose cumulative-normalization interval lies within
    ``[quant[0] * total, quant[1] * total]`` are summed.

    :param prep: 2D array; columns 0, 1 and 2 are read as signal, variance
                 and normalization of each pixel.
    :param data: CSR coefficients, one per entry of ``indices``.
    :param indices: CSR column indices (row index into ``prep``).
    :param indptr: CSR row pointers; there are ``indptr.shape[0] - 1`` bins.
    :param quant: 2-sequence ``(q_start, q_stop)`` of quantile fractions.
    :return: tuple ``(signal, variance, norm, norm2, work)``: four float64
             arrays with one value per bin, plus the sorted work buffer
             (whose ``a`` field then holds the cumulative normalization).
    """
    cdef:
        int i, nbins, npix, idx, j, start, stop
        float sum_
        double s,v,n,n2,d
        float4 w
        float4[::1] work
        float qmin,qmax, q_start, q_stop
        
        double[::1] signal, norm, norm2, variance

    q_start = float(quant[0])
    q_stop  = float(quant[1])
    
    nbins = indptr.shape[0] - 1
    npix = indices.shape[0]
    signal = numpy.zeros(nbins, dtype="float64")
    norm = numpy.zeros(nbins, dtype="float64")
    norm2 = numpy.zeros(nbins, dtype="float64")
    variance = numpy.zeros(nbins, dtype="float64")
    # NOTE(review): prep was presumably obtained with something like
    # preproc_cython(img, mask=ai.detector.mask, solidangle=ai.solidAngleArray(), error_model=ErrorModel.POISSON,split_result=3).reshape((-1,3))
    
    work = numpy.zeros(npix, dtype=numpy.dtype([('a','f4'),('s','f4'),('v','f4'),('n','f4')]))
    # Step 1: expand the CSR matrix into one weighted record per entry.
    for i in prange(npix, nogil=True):
        j = indices[i]
        s = prep[j,0]
        v = prep[j,1]
        n = prep[j,2]
        d = data[i]
        w.a = s/n
        w.s = s * d
        w.v = v * d * d
        w.n = n * d
        work[i] = w
        
    # Step 2: each bin owns the disjoint slice work[start:stop], so bins can
    # be processed in parallel as long as no bin reads outside its slice.
    for idx in prange(nbins, nogil=True, schedule="guided"): #, 
        start = indptr[idx] 
        stop = indptr[idx+1]
        
        # Plain assignments (not in-place operators) keep these thread-private.
        v = s = n = n2 = 0.0
        sort_3(work[start:stop])
        sum_ = 0.0
        # Replace the sort key with the cumulative normalization.
        for i in range(start, stop):
            sum_ = sum_ + work[i].n
            work[i].a = sum_
        qmin = q_start * sum_
        qmax = q_stop * sum_
        for i in range(start, stop):        
            # The lower bound of the first pixel *of this bin* is 0. Testing
            # against `start` (not 0) avoids reading the previous bin's last
            # record, which belongs to another thread's slice.
            if (0.0 if i==start else work[i-1].a) >= qmin and work[i].a <= qmax:
                w = work[i]
                s = s + w.s
                v = v + w.v
                n = n + w.n
                n2 = n2 + w.n*w.n
        signal[idx] = s
        variance[idx] = v
        norm[idx] = n
        norm2[idx] = n2
    
    return numpy.asarray(signal), numpy.asarray(variance), numpy.asarray(norm), numpy.asarray(norm2), numpy.asarray(work)

@kif
Copy link
Member Author

kif commented Oct 8, 2024

Sorting algorithm intended for OpenCL, for ensembles of variable size (prototyped here in Python):

# Comb-sort gap sequence: a hand-picked prefix followed by a geometric
# progression of ratio `raison`, keeping only the gaps generated before the
# progression reaches `size`. `last` ends up as the first rejected gap.
raison = 1.3
pas = [1, 2, 3, 4, 6, 8]
last = 11
while last < size:
    pas.append(last)
    last = ceil(last * raison)

def swap(lst, i, j):
    """Put ``lst[i]`` and ``lst[j]`` in increasing order, in place.

    :return: 1 when the two elements were exchanged, 0 otherwise.
    """
    if not lst[i] > lst[j]:
        return 0
    lst[i], lst[j] = lst[j], lst[i]
    return 1


def passe(lst, pas=1):
    """Run one comb-sort pass with gap ``pas`` over ``lst``, in place.

    The compare-and-swap pairs ``(i, i + pas)`` are grouped into phases; the
    pairs of a phase touch disjoint positions, so they could be executed
    concurrently:

    * ``2 * pas >= len(lst)``: a single phase with every valid pair;
    * ``pas == 1``: two phases, pairs starting at even then at odd positions;
    * otherwise: two phases over blocks of ``pas`` consecutive positions,
      the blocks starting at ``0, 2*pas, ...`` then at ``pas, 3*pas, ...``.

    :return: ``(cnt, perm)``: the number of swaps performed and the list of
             phases, each phase being a list of ``(i, j)`` index pairs.
    """
    n = len(lst)
    if 2 * pas >= n:
        phases = [[(i, i + pas) for i in range(n - pas)]]
    elif pas == 1:
        phases = [[(i, i + pas) for i in range(first, n - pas, 2)]
                  for first in (0, 1)]
    else:
        phases = [[(j, j + pas)
                   for base in range(first, n - pas, 2 * pas)
                   for j in range(base, base + pas)
                   if j + pas < n]
                  for first in (0, pas)]
    # The schedule does not depend on the data, so it can be built first and
    # then applied in the same order.
    swapped = 0
    for phase in phases:
        for low, high in phase:
            swapped += swap(lst, low, high)
    return swapped, phases

def ParallelCombSort(lst):
    """Sort ``lst`` in place with comb-sort passes and record the schedule.

    Gaps come from the module-level ``pas`` list, largest first, skipping
    those not smaller than ``len(lst)``. Gap-1 passes are then repeated for
    as long as the latest pass still swapped something.

    :return: ``(perm, extra)``: the concatenated phases of every pass (each a
             list of ``(i, j)`` pairs) and the number of additional gap-1
             passes that were needed.
    """
    schedule = []
    swapped = 0
    length = len(lst)
    for gap in reversed(pas):
        if gap < length:
            swapped, phases = passe(lst, gap)
            schedule.extend(phases)
    extra = 0
    while swapped:
        swapped, phases = passe(lst, 1)
        schedule.extend(phases)
        extra += 1
    return schedule, extra

The number of passes scales as about 5·log(n).

@kif
Copy link
Member Author

kif commented Oct 24, 2024

OpenCL version of the comb-sort algorithm:

%%cl_kernel

// Puts elements[i] and elements[j] in increasing order.
// Returns 1 when the two values were exchanged, 0 otherwise.
int compare_and_swap(global volatile float* elements, int i, int j)
{
    const float first = elements[i];
    const float second = elements[j];
    const int swapped = (first > second) ? 1 : 0;
    if (swapped)
    {
        elements[i] = second;
        elements[j] = first;
    }
    return swapped;
}

// Performs one comb-sort pass with gap `step` over elements[0..size-1],
// cooperatively within one work-group.
// `shared` is local scratch space with at least get_local_size(0) entries.
// Must be called by every work-item of the group (it contains barriers).
// Returns the total number of swaps performed by the whole work-group,
// the same value in every work-item.
int passe(global volatile float* elements, 
          int size,
          int step,
          local volatile int* shared)
{
    int wg = get_local_size(0);
    int tid = get_local_id(0);
    int cnt = 0;
    int i, j, k, half;
    
    if (2*step>=size)
    {
        // All pairs (i, i+step) are disjoint: a single phase is enough.
        for (i=tid;i<size-step;i+=wg)
            cnt += compare_and_swap(elements, i, i+step);
    }
    else if (step == 1)
    {
        // Odd-even transposition: even pairs, then odd pairs.
        for (i=2*tid; i<size-step; i+=2*wg)
            cnt+=compare_and_swap(elements, i, i+step);
        barrier(CLK_GLOBAL_MEM_FENCE);
        for (i=2*tid+1; i<size-step; i+=2*wg)
            cnt+=compare_and_swap(elements, i, i+step);
    }
    else
    {
        // Blocks of `step` elements starting at 0, 2*step, 4*step, ...
        for (i=tid*2*step; i<size-step; i+=2*step*wg)
        {
            for (j=i; j<i+step; j++)
            {
                k  = j + step;
                if (k<size)
                    cnt += compare_and_swap(elements, j, k);
            }
        }
        barrier(CLK_GLOBAL_MEM_FENCE);
        // ... then blocks starting at step, 3*step, 5*step, ...
        for (i=tid*2*step+step; i<size-step; i+=2*step*wg)
        {
            for (j=i; j<i+step; j++)
            {
                k  = j + step;
                if (k<size)
                    cnt += compare_and_swap(elements, j, k);
            }
        }
    }

    // Local tree reduction summing the swaps of all work-items.
    // Each round folds the upper part [half, i) onto [0, i-half); readers and
    // writers never touch the same slot, and any work-group size is handled.
    shared[tid] = cnt;
    barrier(CLK_LOCAL_MEM_FENCE);
    for (i=wg; i>1; i=half)
    {
        half = (i+1)/2;
        if ((tid+half)<i)
            shared[tid] += shared[tid+half];
        barrier(CLK_LOCAL_MEM_FENCE);
    }
    cnt = shared[0];
    // Every work-item must have read the total before `shared` is reused by
    // the next pass; also publishes the swaps done in global memory.
    barrier(CLK_LOCAL_MEM_FENCE | CLK_GLOBAL_MEM_FENCE);
    return cnt;
}

// Next larger gap: step * raison, rounded towards +infinity.
int inline next_step(int step, float raison)
{
    const float grown = (float)step*raison;
    return convert_int_rtp(grown);
}

// Next smaller gap: step / raison, rounded towards -infinity.
int inline previous_step(int step, float raison)
{
    const float shrunk = (float)step/raison;
    return convert_int_rtn(shrunk);
}

// Sorts elements[0..size-1] in place, in increasing order, with a parallel
// comb sort. Only local ids are used for work distribution.
// NOTE(review): this presumably expects a single work-group of at most 1024
// work-items (size of `shared`) -- confirm at the call site.
kernel void combsort(global volatile float* elements, 
                     int size)
{
    local volatile int shared[1024];
    int step = 11;     // magic value
    float raison=1.3f; // magic value
    int cnt = 0;       // defined even if no pass is ever run

    // Nothing to sort. This also avoids an endless `while (step>=size)` loop
    // once step reaches 0 when size<=0. `size` is uniform over the
    // work-group, so every work-item returns together.
    if (size < 2)
        return;

    // Grow the gap along the geometric progression until it reaches size ...
    while (step<size)
        step=next_step(step, raison);
    if (get_local_id(0) == 0) printf("+ %d %d\n", step, size);

    // ... then step back to the largest gap strictly smaller than size.
    while (step>=size)
        step=previous_step(step, raison);
    if (get_local_id(0) == 0) printf("- %d %d\n", step, size);
    
    // One pass per gap, shrinking the gap down to 1.
    for (step=step; step>0; step=previous_step(step, raison))
    {      
        cnt = passe(elements, size, step, shared);
        //if (get_local_id(0) == 0) printf("o %d %d %d\n", step, size, cnt);
    }

    // Extra gap-1 passes until a pass performs no swap.
    step = 1;
    while (cnt){
        cnt = passe(elements, size, step, shared);
        //if (get_local_id(0) == 0) printf("= %d %d %d\n", step, size, cnt);
    }
    
}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

No branches or pull requests

1 participant