
adding --sort-order option to SortSamSpark (#4545)
* adding --sort-order option to SortSamSpark

adding a --sort-order option to SortSamSpark to let users specify which order to sort the output in
enabling previously disabled tests
fixing tests that weren't actually asserting anything

* closes #1260

* adding hack to get around HadoopGenomics/Hadoop-BAM#199
  created SplitSortingSamInputFormat, which empirically fixes the issue although we don't completely understand the underlying problem
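
Example invocation (hypothetical file names; the flags match the tool's argument definitions):

    gatk SortSamSpark --input input.bam --output sorted.bam --sort-order queryname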
lbergelson authored Jun 11, 2018
1 parent 9ee101d commit 1751c85
Showing 18 changed files with 293 additions and 58 deletions.
ReadsSparkSink.java
@@ -243,7 +243,7 @@ private static void writeReadsADAM(
private static void saveAsShardedHadoopFiles(
final JavaSparkContext ctx, final String outputFile, final String referenceFile,
final SAMFormat samOutputFormat, final JavaRDD<SAMRecord> reads, final SAMFileHeader header,
-            final boolean writeHeader) throws IOException {
+            final boolean writeHeader) {
// Set the static header on the driver thread.
if (samOutputFormat == SAMFormat.CRAM) {
SparkCRAMOutputFormat.setHeader(header);
ReadsSparkSource.java
@@ -7,7 +7,10 @@
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.fs.PathFilter;
import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.mapreduce.InputSplit;
import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.JobContext;
+import org.apache.hadoop.mapreduce.lib.input.FileSplit;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.parquet.avro.AvroParquetInputFormat;
@@ -24,17 +27,17 @@
import org.broadinstitute.hellbender.utils.io.IOUtils;
import org.broadinstitute.hellbender.utils.read.*;
import org.broadinstitute.hellbender.utils.spark.SparkUtils;
-import org.seqdoop.hadoop_bam.AnySAMInputFormat;
-import org.seqdoop.hadoop_bam.BAMInputFormat;
-import org.seqdoop.hadoop_bam.CRAMInputFormat;
-import org.seqdoop.hadoop_bam.SAMRecordWritable;
+import org.seqdoop.hadoop_bam.*;
import org.seqdoop.hadoop_bam.util.SAMHeaderReader;

import java.io.File;
import java.io.IOException;
import java.io.Serializable;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import java.util.stream.Stream;

/** Loads the reads from disk either serially (using samReaderFactory) or in parallel using Hadoop-BAM.
* The parallel code is a modified version of the example writing code from Hadoop-BAM.
@@ -69,6 +72,33 @@ public JavaRDD<GATKRead> getParallelReads(final String readFileName, final Strin
return getParallelReads(readFileName, referencePath, traversalParameters, 0);
}


/**
* this is a hack to work around https://github.com/HadoopGenomics/Hadoop-BAM/issues/199
*
* fix the problem by explicitly sorting the input file splits
*/
public static class SplitSortingSamInputFormat extends AnySAMInputFormat{
@SuppressWarnings("unchecked")
@Override
public List<InputSplit> getSplits(JobContext job) throws IOException {
final List<InputSplit> splits = super.getSplits(job);
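            // Only reorder when every split exposes a file path (FileVirtualSplit for BAM,
            // FileSplit for other formats); sorting the splits by file name empirically
            // works around HadoopGenomics/Hadoop-BAM#199.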

if( splits.stream().allMatch(split -> split instanceof FileVirtualSplit || split instanceof FileSplit)) {
splits.sort(Comparator.comparing(split -> {
if (split instanceof FileVirtualSplit) {
return ((FileVirtualSplit) split).getPath().getName();
} else {
return ((FileSplit) split).getPath().getName();
}
}));
}

return splits;
}
}


/**
* Loads Reads using Hadoop-BAM. For local files, bam must have the fully-qualified path,
* i.e., file:///path/to/bam.bam.
@@ -102,7 +132,7 @@ public JavaRDD<GATKRead> getParallelReads(final String readFileName, final Strin
}
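        // Use the split-sorting wrapper defined above instead of AnySAMInputFormat directly.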

rdd2 = ctx.newAPIHadoopFile(
-                readFileName, AnySAMInputFormat.class, LongWritable.class, SAMRecordWritable.class,
+                readFileName, SplitSortingSamInputFormat.class, LongWritable.class, SAMRecordWritable.class,
conf);

JavaRDD<GATKRead> reads= rdd2.map(v1 -> {
SortSamSpark.java
@@ -8,13 +8,12 @@
import org.broadinstitute.barclay.argparser.CommandLineProgramProperties;
import org.broadinstitute.barclay.help.DocumentedFeature;
import org.broadinstitute.hellbender.cmdline.StandardArgumentDefinitions;
+import org.broadinstitute.hellbender.utils.spark.SparkUtils;
import picard.cmdline.programgroups.ReadDataManipulationProgramGroup;
import org.broadinstitute.hellbender.engine.filters.ReadFilter;
import org.broadinstitute.hellbender.engine.filters.ReadFilterLibrary;
import org.broadinstitute.hellbender.engine.spark.GATKSparkTool;
import org.broadinstitute.hellbender.utils.read.GATKRead;
-import org.broadinstitute.hellbender.utils.read.ReadCoordinateComparator;
-import scala.Tuple2;

import java.util.Collections;
import java.util.List;
@@ -27,35 +26,61 @@
public final class SortSamSpark extends GATKSparkTool {
private static final long serialVersionUID = 1L;

public static final String SORT_ORDER_LONG_NAME = "sort-order";

@Override
public boolean requiresReads() { return true; }

@Argument(doc="the output file path", shortName = StandardArgumentDefinitions.OUTPUT_SHORT_NAME, fullName = StandardArgumentDefinitions.OUTPUT_LONG_NAME, optional = false)
-    protected String outputFile;
+    private String outputFile;

@Argument(doc="sort order of the output file", fullName = SORT_ORDER_LONG_NAME, optional = true)
private SparkSortOrder sortOrder = SparkSortOrder.coordinate;
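    // Note: because the argument is typed as SparkSortOrder, only coordinate and queryname are
    // accepted on the command line; other SAMFileHeader.SortOrder names fail argument parsing
    // (exercised by testBadSortOrders in SortSamSparkIntegrationTest below).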

/**
* SortOrders that have corresponding implementations for spark.
* These correspond to a subset of {@link SAMFileHeader.SortOrder}.
*/
private enum SparkSortOrder {
coordinate(SAMFileHeader.SortOrder.coordinate),
queryname(SAMFileHeader.SortOrder.queryname);

private final SAMFileHeader.SortOrder order;

SparkSortOrder(SAMFileHeader.SortOrder order) {
this.order = order;
}

public SAMFileHeader.SortOrder getSamOrder() {
return order;
}
}

@Override
public List<ReadFilter> getDefaultReadFilters() {
return Collections.singletonList(ReadFilterLibrary.ALLOW_ALL_READS);
}

@Override
protected void onStartup() {
super.onStartup();
}

@Override
protected void runTool(final JavaSparkContext ctx) {
-        JavaRDD<GATKRead> reads = getReads();
-        int numReducers = getRecommendedNumReducers();
-        logger.info("Using %s reducers", numReducers);
+        final JavaRDD<GATKRead> reads = getReads();
+        final int numReducers = getRecommendedNumReducers();
+        logger.info("Using %d reducers", numReducers);

+        final SAMFileHeader header = getHeaderForReads();
+        header.setSortOrder(sortOrder.getSamOrder());

-        final SAMFileHeader readsHeader = getHeaderForReads();
-        ReadCoordinateComparator comparator = new ReadCoordinateComparator(readsHeader);
-        JavaRDD<GATKRead> sortedReads;
+        final JavaRDD<GATKRead> readsToWrite;
        if (shardedOutput) {
-            sortedReads = reads
-                    .mapToPair(read -> new Tuple2<>(read, null))
-                    .sortByKey(comparator, true, numReducers)
-                    .keys();
+            readsToWrite = SparkUtils.sortReadsAccordingToHeader(reads, header, numReducers);
        } else {
-            sortedReads = reads; // sorting is done by writeReads below
+            readsToWrite = reads;
        }
-        readsHeader.setSortOrder(SAMFileHeader.SortOrder.coordinate);
-        writeReads(ctx, outputFile, sortedReads);
+        writeReads(ctx, outputFile, readsToWrite, header);
}
}
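For context, the inline sort removed above suggests what SparkUtils.sortReadsAccordingToHeader presumably does. A minimal sketch under that assumption (the queryname comparator name is a guess; the coordinate case mirrors the deleted code):

    // Sketch only: total-sort a reads RDD according to the header's declared sort order.
    static JavaRDD<GATKRead> sortReadsAccordingToHeader(final JavaRDD<GATKRead> reads,
                                                        final SAMFileHeader header,
                                                        final int numReducers) {
        // Pick a serializable comparator matching the requested order.
        final Comparator<GATKRead> comparator =
                header.getSortOrder() == SAMFileHeader.SortOrder.queryname
                        ? new ReadQueryNameComparator()          // assumed comparator name
                        : new ReadCoordinateComparator(header);  // used by the removed code
        return reads
                .mapToPair(read -> new Tuple2<>(read, (Void) null)) // key by read, dummy value
                .sortByKey(comparator, true, numReducers)           // ascending, fixed partition count
                .keys();
    }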
BaseTest.java
@@ -411,11 +411,27 @@ public static <T> void assertSorted(Iterable<T> iterable, Comparator<T> comparat
* assert that the iterator is sorted according to the comparator
*/
public static <T> void assertSorted(Iterator<T> iterator, Comparator<T> comparator){
assertSorted(iterator, comparator, null);
}


/**
 * assert that the iterable is sorted according to the comparator
*/
public static <T> void assertSorted(Iterable<T> iterable, Comparator<T> comparator, String message){
assertSorted(iterable.iterator(), comparator, message);
}


/**
* assert that the iterator is sorted according to the comparator
*/
public static <T> void assertSorted(Iterator<T> iterator, Comparator<T> comparator, String message){
T previous = null;
while(iterator.hasNext()){
T current = iterator.next();
if( previous != null) {
-                Assert.assertTrue(comparator.compare(previous, current) <= 0, "Expected " + previous + " to be <= " + current);
+                Assert.assertTrue(comparator.compare(previous, current) <= 0, "Expected " + previous + " to be <= " + current + (message == null ? "" : "\n"+message));
}
previous = current;
}
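A quick usage sketch of the new message overload (hypothetical data; the message is appended to the failure text only when the check fails):

    // Assumes java.util.Arrays and java.util.Comparator are imported.
    // Natural-order integers pass the assertion; on failure the message aids debugging.
    BaseTest.assertSorted(Arrays.asList(1, 2, 3).iterator(),
            Comparator.<Integer>naturalOrder(),
            "input should already be sorted");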
SortSamSparkIntegrationTest.java
@@ -1,28 +1,56 @@
package org.broadinstitute.hellbender.tools.spark.pipelines;

import htsjdk.samtools.SAMFileHeader;
import htsjdk.samtools.SAMRecord;
import htsjdk.samtools.SamReaderFactory;
import htsjdk.samtools.ValidationStringency;
import org.apache.spark.api.java.JavaRDD;
import org.broadinstitute.barclay.argparser.CommandLineException;
import org.broadinstitute.hellbender.CommandLineProgramTest;
import org.broadinstitute.hellbender.cmdline.StandardArgumentDefinitions;
import org.broadinstitute.hellbender.engine.ReadsDataSource;
import org.broadinstitute.hellbender.engine.spark.GATKSparkTool;
import org.broadinstitute.hellbender.engine.spark.SparkContextFactory;
import org.broadinstitute.hellbender.engine.spark.datasources.ReadsSparkSource;
import org.broadinstitute.hellbender.tools.spark.pipelines.SortSamSpark;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.read.GATKRead;
import org.broadinstitute.hellbender.utils.test.ArgumentsBuilder;
import org.broadinstitute.hellbender.utils.test.BaseTest;
import org.broadinstitute.hellbender.utils.test.SamAssertionUtils;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;

import java.io.File;
import java.util.List;
import java.util.stream.Collectors;

public final class SortSamSparkIntegrationTest extends CommandLineProgramTest {

public static final String COUNT_READS_SAM = "count_reads.sam";
public static final String COORDINATE_SAM = "count_reads_sorted.sam";
public static final String QUERY_NAME_BAM = "count_reads.bam";
public static final String COORDINATE_BAM = "count_reads_sorted.bam";
public static final String COORDINATE_CRAM = "count_reads_sorted.cram";
public static final String QUERY_NAME_CRAM = "count_reads.cram";
public static final String REF = "count_reads.fasta";
public static final String CRAM = ".cram";
public static final String BAM = ".bam";
public static final String SAM = ".sam";

@DataProvider(name="sortbams")
public Object[][] sortBAMData() {
return new Object[][] {
{"count_reads.sam", "count_reads_sorted.sam", null, ".sam", "coordinate"},
{"count_reads.bam", "count_reads_sorted.bam", null, ".bam", "coordinate"},
{"count_reads.cram", "count_reads_sorted.cram", "count_reads.fasta", ".bam", "coordinate"},
{"count_reads.cram", "count_reads_sorted.cram", "count_reads.fasta", ".cram", "coordinate"},
{"count_reads.bam", "count_reads_sorted.bam", "count_reads.fasta", ".cram", "coordinate"},

//SortBamSpark is missing SORT_ORDER parameter https://github.com/broadinstitute/gatk/issues/1260
// {"count_reads.bam", "count_reads.bam", null, ".bam", "queryname"},
// {"count_reads.cram", "count_reads.cram", "count_reads.fasta", ".cram", "queryname"},
{COUNT_READS_SAM, COORDINATE_SAM, null, SAM, SAMFileHeader.SortOrder.coordinate},
{QUERY_NAME_BAM, COORDINATE_BAM, null, BAM, SAMFileHeader.SortOrder.coordinate},
{QUERY_NAME_CRAM, COORDINATE_CRAM, REF, BAM, SAMFileHeader.SortOrder.coordinate},
{QUERY_NAME_CRAM, COORDINATE_CRAM, REF, CRAM, SAMFileHeader.SortOrder.coordinate},
{QUERY_NAME_BAM, COORDINATE_BAM, REF, CRAM, SAMFileHeader.SortOrder.coordinate},

{COORDINATE_SAM, COUNT_READS_SAM, null, SAM, SAMFileHeader.SortOrder.queryname},
{COORDINATE_BAM, QUERY_NAME_BAM, null, BAM, SAMFileHeader.SortOrder.queryname},
{COORDINATE_CRAM, QUERY_NAME_CRAM, REF, BAM, SAMFileHeader.SortOrder.queryname},
{COORDINATE_CRAM, QUERY_NAME_CRAM, REF, CRAM, SAMFileHeader.SortOrder.queryname},
{COORDINATE_BAM, QUERY_NAME_BAM, REF, CRAM, SAMFileHeader.SortOrder.queryname},
};
}

@@ -32,42 +60,83 @@ public void testSortBAMs(
final String expectedOutputFileName,
final String referenceFileName,
final String outputExtension,
-            final String sortOrderName) throws Exception {
-        final File inputFile = new File(getTestDataDir(), inputFileName);
-        final File expectedOutputFile = new File(getTestDataDir(), expectedOutputFileName);
+            final SAMFileHeader.SortOrder sortOrder) throws Exception {
+        final File inputFile = getTestFile(inputFileName);
+        final File expectedOutputFile = getTestFile(expectedOutputFileName);
        final File actualOutputFile = createTempFile("sort_sam", outputExtension);
-        File referenceFile = null == referenceFileName ? null : new File(getTestDataDir(), referenceFileName);
+        File referenceFile = null == referenceFileName ? null : getTestFile(referenceFileName);

+        final SamReaderFactory factory = SamReaderFactory.makeDefault();

ArgumentsBuilder args = new ArgumentsBuilder();
args.add("--input"); args.add(inputFile.getCanonicalPath());
args.add("--output"); args.add(actualOutputFile.getCanonicalPath());
args.addInput(inputFile);
args.addOutput(actualOutputFile);
if (null != referenceFile) {
args.add("--R");
args.add(referenceFile.getAbsolutePath());
args.addReference(referenceFile);
factory.referenceSequence(referenceFile);
}
args.add("--num-reducers"); args.add("1");
args.addArgument(SortSamSpark.SORT_ORDER_LONG_NAME, sortOrder.name());

//https://github.com/broadinstitute/gatk/issues/1260
// args.add("--SORT_ORDER");
// args.add(sortOrderName);
this.runCommandLine(args);

this.runCommandLine(args.getArgsArray());
//test files are exactly equal
SamAssertionUtils.assertSamsEqual(actualOutputFile, expectedOutputFile, ValidationStringency.DEFAULT_STRINGENCY, referenceFile);

SamAssertionUtils.samsEqualStringent(actualOutputFile, expectedOutputFile, ValidationStringency.DEFAULT_STRINGENCY, referenceFile);
//test sorting matches htsjdk
try(ReadsDataSource in = new ReadsDataSource(actualOutputFile.toPath(), factory )) {
BaseTest.assertSorted(Utils.stream(in).map(read -> read.convertToSAMRecord(in.getHeader())).iterator(), sortOrder.getComparatorInstance());
}
}

@Test(groups = "spark")
public void test() throws Exception {
final File unsortedBam = new File(getTestDataDir(), "count_reads.bam");
final File sortedBam = new File(getTestDataDir(), "count_reads_sorted.bam");
final File outputBam = createTempFile("sort_bam_spark", ".bam");
@Test(dataProvider="sortbams", groups="spark")
public void testSortBAMsSharded(
final String inputFileName,
final String unused,
final String referenceFileName,
final String outputExtension,
final SAMFileHeader.SortOrder sortOrder) {
final File inputFile = getTestFile(inputFileName);
final File actualOutputFile = createTempFile("sort_sam", outputExtension);
File referenceFile = null == referenceFileName ? null : getTestFile(referenceFileName);
ArgumentsBuilder args = new ArgumentsBuilder();
args.add("--"+ StandardArgumentDefinitions.INPUT_LONG_NAME); args.add(unsortedBam.getCanonicalPath());
args.add("--"+StandardArgumentDefinitions.OUTPUT_LONG_NAME); args.add(outputBam.getCanonicalPath());
args.add("--num-reducers"); args.add("1");
args.addInput(inputFile);
args.addOutput(actualOutputFile);
if (null != referenceFile) {
args.addReference(referenceFile);
}
args.addArgument(SortSamSpark.SORT_ORDER_LONG_NAME, sortOrder.name());
args.addBooleanArgument(GATKSparkTool.SHARDED_OUTPUT_LONG_NAME,true);
args.addArgument(GATKSparkTool.NUM_REDUCERS_LONG_NAME, "2");

+        this.runCommandLine(args);

+        final ReadsSparkSource source = new ReadsSparkSource(SparkContextFactory.getTestSparkContext());
+        final JavaRDD<GATKRead> reads = source.getParallelReads(actualOutputFile.getAbsolutePath(), referenceFile == null ? null : referenceFile.getAbsolutePath());

+        final SAMFileHeader header = source.getHeader(actualOutputFile.getAbsolutePath(),
+                referenceFile == null ? null : referenceFile.getAbsolutePath());

-        this.runCommandLine(args.getArgsArray());
+        final List<SAMRecord> reloadedReads = reads.collect().stream().map(read -> read.convertToSAMRecord(header)).collect(Collectors.toList());
+        BaseTest.assertSorted(reloadedReads.iterator(), sortOrder.getComparatorInstance(), reloadedReads.stream().map(SAMRecord::getSAMString).collect(Collectors.joining("\n")));
    }

-        SamAssertionUtils.assertSamsEqual(outputBam, sortedBam);
@DataProvider
public Object[][] getInvalidSortOrders(){
return new Object[][]{
{SAMFileHeader.SortOrder.unknown},
{SAMFileHeader.SortOrder.unsorted},
{SAMFileHeader.SortOrder.duplicate}
};
}

@Test(expectedExceptions = CommandLineException.BadArgumentValue.class, dataProvider = "getInvalidSortOrders")
public void testBadSortOrders(SAMFileHeader.SortOrder badOrder){
final File unsortedBam = new File(getTestDataDir(), QUERY_NAME_BAM);
ArgumentsBuilder args = new ArgumentsBuilder();
args.addInput(unsortedBam);
args.addOutput(createTempFile("sort_bam_spark", BAM));
args.addArgument(SortSamSpark.SORT_ORDER_LONG_NAME, badOrder.toString());

this.runCommandLine(args);
}
}
SAM test file (name not shown)
@@ -7,7 +7,7 @@
@SQ SN:chr6 LN:101
@SQ SN:chr7 LN:404
@SQ SN:chr8 LN:202
-@RG ID:0 SM:Hi,Mom!
+@RG ID:0 SM:Hi,Mom! PL:ILLUMINA
@PG ID:1 PN:Hey! VN:2.0
both_reads_align_clip_marked 83 chr7 1 255 101M = 302 201 CAACAGAAGCNGGNATCTGTGTTTGTGTTTCGGATTTCCTGCTGAANNGNTTNTCGNNTCNNNNNNNNATCCCGATTTCNTTCCGCAGCTNACCTCCCAAN )'.*.+2,))&&'&*/)-&*-)&.-)&)&),/-&&..)./.,.).*&&,&.&&-)&&&0*&&&&&&&&/32/,01460&&/6/*0*/2/283//36868/& RG:Z:0
both_reads_present_only_first_aligns 89 chr7 1 255 101M * 0 0 CAACAGAAGCNGGNATCTGTGTTTGTGTTTCGGATTTCCTGCTGAANNGNTTNTCGNNTCNNNNNNNNATCCCGATTTCNTTCCGCAGCTNACCTCCCAAN )'.*.+2,))&&'&*/)-&*-)&.-)&)&),/-&&..)./.,.).*&&,&.&&-)&&&0*&&&&&&&&/32/,01460&&/6/*0*/2/283//36868/& RG:Z:0
3 binary test files changed (contents not shown); remaining changed files not rendered.
