Skip to content

Commit 1a99586

Browse files
committed
test: label propagation with DistanceTransform
1 parent 19cb4ec commit 1a99586

File tree

1 file changed

+186
-5
lines changed

1 file changed

+186
-5
lines changed

src/test/java/net/imglib2/algorithm/morphology/distance/DistanceTransformTest.java

Lines changed: 186 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,14 @@
3434

3535
package net.imglib2.algorithm.morphology.distance;
3636

37+
import static org.junit.Assert.assertTrue;
38+
39+
import java.util.ArrayList;
3740
import java.util.Arrays;
41+
import java.util.HashSet;
42+
import java.util.List;
3843
import java.util.Random;
44+
import java.util.Set;
3945
import java.util.concurrent.ExecutionException;
4046
import java.util.concurrent.ExecutorService;
4147
import java.util.concurrent.Executors;
@@ -46,6 +52,7 @@
4652
import org.junit.Test;
4753

4854
import net.imglib2.Cursor;
55+
import net.imglib2.Interval;
4956
import net.imglib2.Localizable;
5057
import net.imglib2.Point;
5158
import net.imglib2.RandomAccess;
@@ -58,16 +65,20 @@
5865
import net.imglib2.img.basictypeaccess.array.DoubleArray;
5966
import net.imglib2.img.basictypeaccess.array.LongArray;
6067
import net.imglib2.type.logic.BitType;
68+
import net.imglib2.type.numeric.IntegerType;
6169
import net.imglib2.type.numeric.RealType;
70+
import net.imglib2.type.numeric.integer.LongType;
6271
import net.imglib2.type.numeric.real.DoubleType;
6372
import net.imglib2.util.Intervals;
73+
import net.imglib2.util.Localizables;
6474
import net.imglib2.util.Pair;
75+
import net.imglib2.util.Util;
6576
import net.imglib2.view.Views;
6677

6778
/**
6879
*
6980
* @author Philipp Hanslovsky
70-
*
81+
* @author John Bogovic
7182
*/
7283
public class DistanceTransformTest
7384
{
@@ -159,9 +170,7 @@ private void testBinary( final DISTANCE_TYPE dt, final DistanceCalculator distan
159170

160171
private static void compareRAIofRealType( final RandomAccessibleInterval< ? extends RealType< ? > > ref, final RandomAccessibleInterval< ? extends RealType< ? > > comp, final double tolerance )
161172
{
162-
Assert.assertArrayEquals( Intervals.dimensionsAsLongArray( ref ), Intervals.dimensionsAsLongArray( comp ) );
163-
Assert.assertArrayEquals( Intervals.minAsLongArray( ref ), Intervals.minAsLongArray( comp ) );
164-
Assert.assertArrayEquals( Intervals.maxAsLongArray( ref ), Intervals.maxAsLongArray( comp ) );
173+
assertTrue( Intervals.equals( ref, comp ) );
165174
for ( final Pair< ? extends RealType< ? >, ? extends RealType< ? > > p : Views.flatIterable( Views.interval( Views.pair( ref, comp ), ref ) ) )
166175
{
167176
Assert.assertEquals( p.getA().getRealDouble(), p.getB().getRealDouble(), tolerance );
@@ -440,12 +449,184 @@ private static < T extends RealType< T > > void checkDistance(
440449
final double[] weights,
441450
final DistanceCalculator distanceCalculator )
442451
{
443-
for ( final Cursor< T > c = Views.iterable( dist ).localizingCursor(); c.hasNext(); )
452+
for ( final Cursor< T > c = dist.localizingCursor(); c.hasNext(); )
444453
{
445454
final double actual = c.next().getRealDouble();
446455
final double expected = atSamePosition( foreground, c ) ? 0.0 : distanceCalculator.dist( foreground, c, weights );
447456
Assert.assertEquals( expected, actual, 0.0 );
448457
}
449458
}
450459

460+
@Test
461+
public void testLabelPropagation()
462+
{
463+
/*
464+
* Iterate over numReplicates = [0..9] numDimensions = [2, 3] numLabels
465+
* = [1..5]
466+
*/
467+
final int firstReplicate = 0;
468+
final int lastReplicate = 9;
469+
470+
final int firstNumDimensions = 2;
471+
final int lastNumDimensions = 3;
472+
473+
final int firstNumLabels = 2;
474+
final int lastNumLabels = 5;
475+
476+
final RandomAccessibleInterval< Localizable > parameters = Localizables.randomAccessibleInterval(
477+
Intervals.createMinMax(
478+
firstReplicate, firstNumDimensions,
479+
firstNumLabels, lastReplicate,
480+
lastNumDimensions, lastNumLabels ) );
481+
482+
parameters.forEach( params -> {
483+
484+
@SuppressWarnings( "unused" )
485+
final int replicate = params.getIntPosition( 0 );
486+
final int numDimensions = params.getIntPosition( 1 );
487+
final int numLabels = params.getIntPosition( 2 );
488+
489+
testLabelPropagationHelper( numDimensions, numLabels );
490+
testLabelPropagationHelperParallel( numDimensions, numLabels );
491+
} );
492+
493+
}
494+
495+
/**
496+
* Creates an label and distances images with the requested number of dimensions (ndims),
497+
* and places nLabels points with non-zero label. Checks that the propagated labels correctly
498+
* reflect the nearest label (ties are allowed: any label equi-distant to a point passes).
499+
*
500+
* @param ndims number of dimensions
501+
* @param nLabels number of labels
502+
*/
503+
private void testLabelPropagationHelper( int ndims, int nLabels )
504+
{
505+
506+
final long[] imgDims = LongStream.iterate( dimensionSize, d -> d - 1 ).limit( ndims ).toArray();
507+
final ArrayImg< LongType, LongArray > labels = ArrayImgs.longs( imgDims );
508+
509+
final Set< PointAndLabel > points = initializeLabels( rng, nLabels, labels );
510+
DistanceTransform.labelTransform( labels, 0 );
511+
validateLabelsSet( "serial", points, labels );
512+
}
513+
514+
/**
515+
* Creates an label and distances images with the requested number of dimensions (ndims),
516+
* and places nLabels points with non-zero label. Checks that the propagated labels correctly
517+
* reflect the nearest label (ties are allowed: any label equi-distant to a point passes).
518+
*
519+
* @param ndims number of dimensions
520+
* @param nLabels number of labels
521+
*/
522+
private void testLabelPropagationHelperParallel( int ndims, int nLabels )
523+
{
524+
525+
final long[] imgDims = LongStream.iterate( dimensionSize, d -> d - 1 ).limit( ndims ).toArray();
526+
final ArrayImg< LongType, LongArray > labels = ArrayImgs.longs( imgDims );
527+
final Set< PointAndLabel > points = initializeLabels( rng, nLabels, labels );
528+
DistanceTransform.labelTransform( labels, 0, es, 3 * nThreads );
529+
validateLabelsSet( "parallel", points, labels );
530+
}
531+
532+
private ArrayImg< LongType, LongArray > copyLongArrayImg( ArrayImg< LongType, LongArray > img )
533+
{
534+
535+
final long[] dataOrig = img.getAccessType().getCurrentStorageArray();
536+
final long[] dataCopy = new long[ dataOrig.length ];
537+
System.arraycopy( dataOrig, 0, dataCopy, 0, dataOrig.length );
538+
return ArrayImgs.longs( dataCopy, img.dimensionsAsLongArray() );
539+
}
540+
541+
private static Point randomPointInInterval( final Random rng, final Interval itvl )
542+
{
543+
final int[] coords = IntStream.range( 0, itvl.numDimensions() ).map( i -> {
544+
return rng.nextInt( ( int ) itvl.dimension( i ) );
545+
} ).toArray();
546+
return new Point( coords );
547+
}
548+
549+
private static < T extends RealType< T >, L extends IntegerType< L > > Set< PointAndLabel > initializeLabels( Random random, int numLabels, RandomAccessibleInterval< L > labels )
550+
{
551+
labels.forEach( p -> p.setZero() ); // Initialize all labels to 0
552+
Set< PointAndLabel > positions = new HashSet<>();
553+
554+
int currentLabel = 1;
555+
// Set numLabels different random positions to a non-zero label
556+
while ( positions.size() < numLabels )
557+
{
558+
final Point pt = randomPointInInterval( random, labels );
559+
if ( !positions.contains( pt ) )
560+
{
561+
562+
final PointAndLabel candidate = new PointAndLabel( currentLabel, pt.positionAsLongArray() );
563+
if ( !positions.contains( candidate ) )
564+
{
565+
positions.add( candidate );
566+
labels.randomAccess().setPositionAndGet( pt ).setInteger( currentLabel );
567+
currentLabel++;
568+
}
569+
570+
}
571+
}
572+
return positions;
573+
}
574+
575+
/**
576+
* Return the set of points within epsilon distance of the query point
577+
*
578+
* @param query point
579+
* @param pointSet set of candidate points
580+
* @param epsilon distance threshold
581+
* @return the set of close points
582+
*/
583+
private static List< PointAndLabel > closestSet( Localizable query, Set< PointAndLabel > pointSet, final double epsilon )
584+
{
585+
586+
final List< PointAndLabel > listOfEquidistant = new ArrayList<>();
587+
588+
double mindist = Double.MAX_VALUE;
589+
for ( PointAndLabel pt : pointSet )
590+
{
591+
double dist = Util.distance( query, pt );
592+
593+
if ( Math.abs( dist - mindist ) < epsilon )
594+
{
595+
listOfEquidistant.add( pt );
596+
}
597+
else if ( dist < mindist )
598+
{
599+
mindist = dist;
600+
listOfEquidistant.clear();
601+
listOfEquidistant.add( pt );
602+
}
603+
}
604+
605+
return listOfEquidistant;
606+
}
607+
608+
private static < T extends RealType< T >, L extends IntegerType< L > > void validateLabelsSet( final String prefix, final Set< PointAndLabel > points, final RandomAccessibleInterval< L > labels )
609+
{
610+
final double EPS = 0.01;
611+
final Cursor< L > c = labels.cursor();
612+
while ( c.hasNext() )
613+
{
614+
c.fwd();
615+
final boolean labelIsClosest = closestSet( c, points, EPS ).stream().anyMatch( p -> p.label == c.get().getIntegerLong() );
616+
assertTrue( prefix + " point: " + Arrays.toString( c.positionAsLongArray() ), labelIsClosest );
617+
}
618+
}
619+
620+
private static class PointAndLabel extends Point
621+
{
622+
623+
long label;
624+
625+
public PointAndLabel( long label, long[] position )
626+
{
627+
super( position );
628+
this.label = label;
629+
}
630+
}
631+
451632
}

0 commit comments

Comments
 (0)