|
34 | 34 |
|
35 | 35 | package net.imglib2.algorithm.morphology.distance; |
36 | 36 |
|
| 37 | +import static org.junit.Assert.assertTrue; |
| 38 | + |
| 39 | +import java.util.ArrayList; |
37 | 40 | import java.util.Arrays; |
| 41 | +import java.util.HashSet; |
| 42 | +import java.util.List; |
38 | 43 | import java.util.Random; |
| 44 | +import java.util.Set; |
39 | 45 | import java.util.concurrent.ExecutionException; |
40 | 46 | import java.util.concurrent.ExecutorService; |
41 | 47 | import java.util.concurrent.Executors; |
|
46 | 52 | import org.junit.Test; |
47 | 53 |
|
48 | 54 | import net.imglib2.Cursor; |
| 55 | +import net.imglib2.Interval; |
49 | 56 | import net.imglib2.Localizable; |
50 | 57 | import net.imglib2.Point; |
51 | 58 | import net.imglib2.RandomAccess; |
|
58 | 65 | import net.imglib2.img.basictypeaccess.array.DoubleArray; |
59 | 66 | import net.imglib2.img.basictypeaccess.array.LongArray; |
60 | 67 | import net.imglib2.type.logic.BitType; |
| 68 | +import net.imglib2.type.numeric.IntegerType; |
61 | 69 | import net.imglib2.type.numeric.RealType; |
| 70 | +import net.imglib2.type.numeric.integer.LongType; |
62 | 71 | import net.imglib2.type.numeric.real.DoubleType; |
63 | 72 | import net.imglib2.util.Intervals; |
| 73 | +import net.imglib2.util.Localizables; |
64 | 74 | import net.imglib2.util.Pair; |
| 75 | +import net.imglib2.util.Util; |
65 | 76 | import net.imglib2.view.Views; |
66 | 77 |
|
67 | 78 | /** |
68 | 79 | * |
69 | 80 | * @author Philipp Hanslovsky |
70 | | - * |
| 81 | + * @author John Bogovic |
71 | 82 | */ |
72 | 83 | public class DistanceTransformTest |
73 | 84 | { |
@@ -159,9 +170,7 @@ private void testBinary( final DISTANCE_TYPE dt, final DistanceCalculator distan |
159 | 170 |
|
160 | 171 | private static void compareRAIofRealType( final RandomAccessibleInterval< ? extends RealType< ? > > ref, final RandomAccessibleInterval< ? extends RealType< ? > > comp, final double tolerance ) |
161 | 172 | { |
162 | | - Assert.assertArrayEquals( Intervals.dimensionsAsLongArray( ref ), Intervals.dimensionsAsLongArray( comp ) ); |
163 | | - Assert.assertArrayEquals( Intervals.minAsLongArray( ref ), Intervals.minAsLongArray( comp ) ); |
164 | | - Assert.assertArrayEquals( Intervals.maxAsLongArray( ref ), Intervals.maxAsLongArray( comp ) ); |
| 173 | + assertTrue( Intervals.equals( ref, comp ) ); |
165 | 174 | for ( final Pair< ? extends RealType< ? >, ? extends RealType< ? > > p : Views.flatIterable( Views.interval( Views.pair( ref, comp ), ref ) ) ) |
166 | 175 | { |
167 | 176 | Assert.assertEquals( p.getA().getRealDouble(), p.getB().getRealDouble(), tolerance ); |
@@ -440,12 +449,184 @@ private static < T extends RealType< T > > void checkDistance( |
440 | 449 | final double[] weights, |
441 | 450 | final DistanceCalculator distanceCalculator ) |
442 | 451 | { |
443 | | - for ( final Cursor< T > c = Views.iterable( dist ).localizingCursor(); c.hasNext(); ) |
| 452 | + for ( final Cursor< T > c = dist.localizingCursor(); c.hasNext(); ) |
444 | 453 | { |
445 | 454 | final double actual = c.next().getRealDouble(); |
446 | 455 | final double expected = atSamePosition( foreground, c ) ? 0.0 : distanceCalculator.dist( foreground, c, weights ); |
447 | 456 | Assert.assertEquals( expected, actual, 0.0 ); |
448 | 457 | } |
449 | 458 | } |
450 | 459 |
|
| 460 | + @Test |
| 461 | + public void testLabelPropagation() |
| 462 | + { |
| 463 | + /* |
| 464 | + * Iterate over numReplicates = [0..9] numDimensions = [2, 3] numLabels |
| 465 | + * = [1..5] |
| 466 | + */ |
| 467 | + final int firstReplicate = 0; |
| 468 | + final int lastReplicate = 9; |
| 469 | + |
| 470 | + final int firstNumDimensions = 2; |
| 471 | + final int lastNumDimensions = 3; |
| 472 | + |
| 473 | + final int firstNumLabels = 2; |
| 474 | + final int lastNumLabels = 5; |
| 475 | + |
| 476 | + final RandomAccessibleInterval< Localizable > parameters = Localizables.randomAccessibleInterval( |
| 477 | + Intervals.createMinMax( |
| 478 | + firstReplicate, firstNumDimensions, |
| 479 | + firstNumLabels, lastReplicate, |
| 480 | + lastNumDimensions, lastNumLabels ) ); |
| 481 | + |
| 482 | + parameters.forEach( params -> { |
| 483 | + |
| 484 | + @SuppressWarnings( "unused" ) |
| 485 | + final int replicate = params.getIntPosition( 0 ); |
| 486 | + final int numDimensions = params.getIntPosition( 1 ); |
| 487 | + final int numLabels = params.getIntPosition( 2 ); |
| 488 | + |
| 489 | + testLabelPropagationHelper( numDimensions, numLabels ); |
| 490 | + testLabelPropagationHelperParallel( numDimensions, numLabels ); |
| 491 | + } ); |
| 492 | + |
| 493 | + } |
| 494 | + |
| 495 | + /** |
| 496 | + * Creates an label and distances images with the requested number of dimensions (ndims), |
| 497 | + * and places nLabels points with non-zero label. Checks that the propagated labels correctly |
| 498 | + * reflect the nearest label (ties are allowed: any label equi-distant to a point passes). |
| 499 | + * |
| 500 | + * @param ndims number of dimensions |
| 501 | + * @param nLabels number of labels |
| 502 | + */ |
| 503 | + private void testLabelPropagationHelper( int ndims, int nLabels ) |
| 504 | + { |
| 505 | + |
| 506 | + final long[] imgDims = LongStream.iterate( dimensionSize, d -> d - 1 ).limit( ndims ).toArray(); |
| 507 | + final ArrayImg< LongType, LongArray > labels = ArrayImgs.longs( imgDims ); |
| 508 | + |
| 509 | + final Set< PointAndLabel > points = initializeLabels( rng, nLabels, labels ); |
| 510 | + DistanceTransform.labelTransform( labels, 0 ); |
| 511 | + validateLabelsSet( "serial", points, labels ); |
| 512 | + } |
| 513 | + |
| 514 | + /** |
| 515 | + * Creates an label and distances images with the requested number of dimensions (ndims), |
| 516 | + * and places nLabels points with non-zero label. Checks that the propagated labels correctly |
| 517 | + * reflect the nearest label (ties are allowed: any label equi-distant to a point passes). |
| 518 | + * |
| 519 | + * @param ndims number of dimensions |
| 520 | + * @param nLabels number of labels |
| 521 | + */ |
| 522 | + private void testLabelPropagationHelperParallel( int ndims, int nLabels ) |
| 523 | + { |
| 524 | + |
| 525 | + final long[] imgDims = LongStream.iterate( dimensionSize, d -> d - 1 ).limit( ndims ).toArray(); |
| 526 | + final ArrayImg< LongType, LongArray > labels = ArrayImgs.longs( imgDims ); |
| 527 | + final Set< PointAndLabel > points = initializeLabels( rng, nLabels, labels ); |
| 528 | + DistanceTransform.labelTransform( labels, 0, es, 3 * nThreads ); |
| 529 | + validateLabelsSet( "parallel", points, labels ); |
| 530 | + } |
| 531 | + |
| 532 | + private ArrayImg< LongType, LongArray > copyLongArrayImg( ArrayImg< LongType, LongArray > img ) |
| 533 | + { |
| 534 | + |
| 535 | + final long[] dataOrig = img.getAccessType().getCurrentStorageArray(); |
| 536 | + final long[] dataCopy = new long[ dataOrig.length ]; |
| 537 | + System.arraycopy( dataOrig, 0, dataCopy, 0, dataOrig.length ); |
| 538 | + return ArrayImgs.longs( dataCopy, img.dimensionsAsLongArray() ); |
| 539 | + } |
| 540 | + |
| 541 | + private static Point randomPointInInterval( final Random rng, final Interval itvl ) |
| 542 | + { |
| 543 | + final int[] coords = IntStream.range( 0, itvl.numDimensions() ).map( i -> { |
| 544 | + return rng.nextInt( ( int ) itvl.dimension( i ) ); |
| 545 | + } ).toArray(); |
| 546 | + return new Point( coords ); |
| 547 | + } |
| 548 | + |
| 549 | + private static < T extends RealType< T >, L extends IntegerType< L > > Set< PointAndLabel > initializeLabels( Random random, int numLabels, RandomAccessibleInterval< L > labels ) |
| 550 | + { |
| 551 | + labels.forEach( p -> p.setZero() ); // Initialize all labels to 0 |
| 552 | + Set< PointAndLabel > positions = new HashSet<>(); |
| 553 | + |
| 554 | + int currentLabel = 1; |
| 555 | + // Set numLabels different random positions to a non-zero label |
| 556 | + while ( positions.size() < numLabels ) |
| 557 | + { |
| 558 | + final Point pt = randomPointInInterval( random, labels ); |
| 559 | + if ( !positions.contains( pt ) ) |
| 560 | + { |
| 561 | + |
| 562 | + final PointAndLabel candidate = new PointAndLabel( currentLabel, pt.positionAsLongArray() ); |
| 563 | + if ( !positions.contains( candidate ) ) |
| 564 | + { |
| 565 | + positions.add( candidate ); |
| 566 | + labels.randomAccess().setPositionAndGet( pt ).setInteger( currentLabel ); |
| 567 | + currentLabel++; |
| 568 | + } |
| 569 | + |
| 570 | + } |
| 571 | + } |
| 572 | + return positions; |
| 573 | + } |
| 574 | + |
| 575 | + /** |
| 576 | + * Return the set of points within epsilon distance of the query point |
| 577 | + * |
| 578 | + * @param query point |
| 579 | + * @param pointSet set of candidate points |
| 580 | + * @param epsilon distance threshold |
| 581 | + * @return the set of close points |
| 582 | + */ |
| 583 | + private static List< PointAndLabel > closestSet( Localizable query, Set< PointAndLabel > pointSet, final double epsilon ) |
| 584 | + { |
| 585 | + |
| 586 | + final List< PointAndLabel > listOfEquidistant = new ArrayList<>(); |
| 587 | + |
| 588 | + double mindist = Double.MAX_VALUE; |
| 589 | + for ( PointAndLabel pt : pointSet ) |
| 590 | + { |
| 591 | + double dist = Util.distance( query, pt ); |
| 592 | + |
| 593 | + if ( Math.abs( dist - mindist ) < epsilon ) |
| 594 | + { |
| 595 | + listOfEquidistant.add( pt ); |
| 596 | + } |
| 597 | + else if ( dist < mindist ) |
| 598 | + { |
| 599 | + mindist = dist; |
| 600 | + listOfEquidistant.clear(); |
| 601 | + listOfEquidistant.add( pt ); |
| 602 | + } |
| 603 | + } |
| 604 | + |
| 605 | + return listOfEquidistant; |
| 606 | + } |
| 607 | + |
| 608 | + private static < T extends RealType< T >, L extends IntegerType< L > > void validateLabelsSet( final String prefix, final Set< PointAndLabel > points, final RandomAccessibleInterval< L > labels ) |
| 609 | + { |
| 610 | + final double EPS = 0.01; |
| 611 | + final Cursor< L > c = labels.cursor(); |
| 612 | + while ( c.hasNext() ) |
| 613 | + { |
| 614 | + c.fwd(); |
| 615 | + final boolean labelIsClosest = closestSet( c, points, EPS ).stream().anyMatch( p -> p.label == c.get().getIntegerLong() ); |
| 616 | + assertTrue( prefix + " point: " + Arrays.toString( c.positionAsLongArray() ), labelIsClosest ); |
| 617 | + } |
| 618 | + } |
| 619 | + |
| 620 | + private static class PointAndLabel extends Point |
| 621 | + { |
| 622 | + |
| 623 | + long label; |
| 624 | + |
| 625 | + public PointAndLabel( long label, long[] position ) |
| 626 | + { |
| 627 | + super( position ); |
| 628 | + this.label = label; |
| 629 | + } |
| 630 | + } |
| 631 | + |
451 | 632 | } |
0 commit comments