Skip to content

Commit 3a433fc

Browse files
committed
Adding new individual tasks for impc_web_api
1 parent 927550c commit 3a433fc

30 files changed

+6141
-0
lines changed
Lines changed: 247 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,247 @@
1+
import csv
2+
import json
3+
import os
4+
import re
5+
6+
import luigi
7+
from luigi.contrib.spark import PySparkTask
8+
from pyspark import SparkContext
9+
from pyspark.sql import SparkSession, Window
10+
from pyspark.sql.functions import (
11+
col,
12+
first,
13+
explode,
14+
zip_with,
15+
struct,
16+
when,
17+
sum,
18+
collect_set,
19+
lit,
20+
concat,
21+
max,
22+
min,
23+
regexp_replace,
24+
split,
25+
arrays_zip,
26+
expr,
27+
concat_ws,
28+
countDistinct,
29+
array_contains,
30+
array_union,
31+
array,
32+
udf,
33+
row_number,
34+
avg,
35+
stddev,
36+
count,
37+
quarter,
38+
regexp_extract,
39+
array_distinct,
40+
lower,
41+
size,
42+
array_intersect,
43+
trim,
44+
explode_outer,
45+
desc,
46+
)
47+
from pyspark.sql.types import (
48+
DoubleType,
49+
IntegerType,
50+
BooleanType,
51+
ArrayType,
52+
StringType,
53+
StructType,
54+
StructField,
55+
)
56+
57+
from impc_etl.jobs.clean.specimen_cleaner import (
58+
MouseSpecimenCleaner,
59+
EmbryoSpecimenCleaner,
60+
)
61+
from impc_etl.jobs.extract import MGIStrainReportExtractor
62+
from impc_etl.jobs.extract.ontology_hierarchy_extractor import (
63+
OntologyTermHierarchyExtractor,
64+
)
65+
from impc_etl.jobs.load import ExperimentToObservationMapper
66+
from impc_etl.jobs.load.solr.gene_mapper import GeneLoader
67+
from impc_etl.jobs.load.solr.genotype_phenotype_mapper import GenotypePhenotypeLoader
68+
from impc_etl.jobs.load.solr.impc_images_mapper import ImpcImagesLoader
69+
from impc_etl.jobs.load.solr.mp_mapper import MpLoader
70+
from impc_etl.jobs.load.solr.pipeline_mapper import ImpressToParameterMapper
71+
from impc_etl.jobs.load.solr.stats_results_mapper import StatsResultsMapper
72+
from impc_etl.workflow import SmallPySparkTask
73+
from impc_etl.workflow.config import ImpcConfig
74+
75+
# Maps legacy gene-core field names (snake_case, as produced by the Solr/ETL
# pipeline) to the camelCase field names exposed by the IMPC web API gene
# summaries. Keys are source column names; values are the renamed output fields.
GENE_SUMMARY_MAPPINGS = {
    "mgi_accession_id": "mgiGeneAccessionId",
    "marker_symbol": "geneSymbol",
    "marker_name": "geneName",
    "marker_synonym": "synonyms",
    "significant_top_level_mp_terms": "significantTopLevelPhenotypes",
    "not_significant_top_level_mp_terms": "notSignificantTopLevelPhenotypes",
    "embryo_data_available": "hasEmbryoImagingData",
    "human_gene_symbol": "human_gene_symbols",
    "human_symbol_synonym": "human_symbol_synonyms",
    "production_centre": "production_centres",
    "phenotyping_centre": "phenotyping_centres",
    "allele_name": "allele_names",
    "ensembl_gene_id": "ensembl_gene_ids",
}
90+
91+
92+
def get_lacz_expression_count(observations_df, lacz_lifestage):
    """Count distinct LacZ categorical expression observations per gene.

    Selects categorical observations from the "Adult LacZ" or "Embryo LacZ"
    procedure (excluding the two image parameters), de-duplicates on
    (gene, zygosity, lower-cased parameter name), and counts the remaining
    rows per gene.

    Args:
        observations_df: observations DataFrame with at least the columns
            procedure_name, observation_type, parameter_name,
            gene_accession_id and zygosity.
        lacz_lifestage: "adult" selects the "Adult LacZ" procedure; any other
            value selects "Embryo LacZ".

    Returns:
        DataFrame with columns ``id`` (the gene accession id) and
        ``{lacz_lifestage}ExpressionObservationsCount``.
    """
    procedure_name = "Adult LacZ" if lacz_lifestage == "adult" else "Embryo LacZ"

    # Keep only categorical LacZ observations; the two image parameters are
    # counted separately elsewhere.
    categorical_lacz = observations_df.where(
        (col("procedure_name") == procedure_name)
        & (col("observation_type") == "categorical")
        & (col("parameter_name") != "LacZ Images Section")
        & (col("parameter_name") != "LacZ Images Wholemount")
    )

    # One row per (gene, zygosity, normalized parameter name).
    distinct_rows = categorical_lacz.select(
        "gene_accession_id",
        "zygosity",
        lower("parameter_name").alias("parameter_name"),
    ).distinct()

    # Count the distinct rows per gene (parameter_name is non-null after the
    # filters above, so this is effectively a row count).
    counts_per_gene = distinct_rows.groupBy("gene_accession_id").agg(
        sum(when(col("parameter_name").isNotNull(), 1).otherwise(0)).alias("count")
    )

    return counts_per_gene.withColumnRenamed(
        "count", f"{lacz_lifestage}ExpressionObservationsCount"
    ).withColumnRenamed("gene_accession_id", "id")
114+
115+
116+
def get_lacz_expression_data(observations_df, lacz_lifestage):
    """Build per-gene LacZ expression data for the given life stage.

    Aggregates mutant categorical LacZ observations per
    (gene, zygosity, parameter), attaches matching wild-type control counts
    per parameter, and joins in the LacZ image parameters associated with
    each anatomical parameter.

    Args:
        observations_df: observations DataFrame (categorical and image_record
            rows for the LacZ procedures).
        lacz_lifestage: "adult" selects the "Adult LacZ" procedure; any other
            value selects "Embryo LacZ". Also stamped on the output as
            ``lacZLifestage``.

    Returns:
        Distinct DataFrame keyed by (gene_accession_id, zygosity,
        parameter_name) with ``mutantCounts``, ``controlCounts``,
        ``mutant_parameter_stable_ids``, ``expression_image_parameters`` and
        ``lacZLifestage`` columns.
    """
    procedure_name = "Adult LacZ" if lacz_lifestage == "adult" else "Embryo LacZ"

    # Normalize parameter names once so all downstream matching is
    # case-insensitive.
    observations_df = observations_df.withColumn(
        "parameter_name", lower("parameter_name")
    )

    # BUG FIX: parameter_name is lower-cased above, so the original
    # comparisons against the mixed-case literals "LacZ Images Section" /
    # "LacZ Images Wholemount" could never match and the exclusion was a
    # no-op. Compare against the lower-cased literals instead (mirrors the
    # image filter below and the pre-lowercase filter in
    # get_lacz_expression_count).
    lacz_observations = observations_df.where(
        (col("procedure_name") == procedure_name)
        & (col("observation_type") == "categorical")
        & (col("parameter_name") != "lacz images section")
        & (col("parameter_name") != "lacz images wholemount")
    )

    # Expression call categories recorded for LacZ parameters.
    categories = [
        "expression",
        "tissue not available",
        "no expression",
        "imageOnly",
        "ambiguous",
    ]

    # Per (gene, zygosity, parameter): one counter column per category plus
    # the set of parameter stable ids that contributed.
    lacz_observations_by_gene = lacz_observations.groupBy(
        "gene_accession_id",
        "zygosity",
        "parameter_name",
    ).agg(
        *[
            sum(when(col("category") == category, 1).otherwise(0)).alias(
                to_camel_case(category.replace(" ", "_"))
            )
            for category in categories
        ],
        collect_set(
            "parameter_stable_id",
        ).alias("mutant_parameter_stable_ids"),
    )
    # Fold the per-category counters into a single struct column.
    lacz_observations_by_gene = lacz_observations_by_gene.withColumn(
        "mutantCounts",
        struct(*[to_camel_case(category.replace(" ", "_")) for category in categories]),
    )

    lacz_observations_by_gene = lacz_observations_by_gene.select(
        "gene_accession_id",
        "zygosity",
        "mutant_parameter_stable_ids",
        "parameter_name",
        "mutantCounts",
    ).distinct()

    # Wild-type (control) counts are aggregated per parameter only, then
    # joined back to every gene row for that parameter.
    wt_lacz_observations_by_strain = lacz_observations.where(
        col("biological_sample_group") == "control"
    )

    wt_lacz_observations_by_strain = wt_lacz_observations_by_strain.groupBy(
        "parameter_name"
    ).agg(
        *[
            sum(when(col("category") == category, 1).otherwise(0)).alias(
                to_camel_case(category.replace(" ", "_"))
            )
            for category in categories
        ],
        collect_set(
            "parameter_stable_id",
        ).alias("control_parameter_stable_ids"),
    )

    wt_lacz_observations_by_strain = wt_lacz_observations_by_strain.withColumn(
        "controlCounts",
        struct(*[to_camel_case(category.replace(" ", "_")) for category in categories]),
    )

    wt_lacz_observations_by_strain = wt_lacz_observations_by_strain.select(
        "parameter_name", "controlCounts"
    )

    lacz_observations_by_gene = lacz_observations_by_gene.join(
        wt_lacz_observations_by_strain,
        ["parameter_name"],
        "left_outer",
    )

    # Image records: each image row carries the anatomical parameters it is
    # associated with; explode those so images can be joined to the
    # categorical rows by anatomical parameter name.
    # (lower() here is redundant after the normalization above, but harmless.)
    lacz_images_by_gene = observations_df.where(
        (col("procedure_name") == procedure_name)
        & (col("observation_type") == "image_record")
        & (
            (lower(col("parameter_name")) == "lacz images section")
            | (lower(col("parameter_name")) == "lacz images wholemount")
        )
    )

    lacz_images_by_gene = lacz_images_by_gene.select(
        struct(
            "parameter_stable_id",
            "parameter_name",
        ).alias("expression_image_parameter"),
        "gene_accession_id",
        "zygosity",
        explode("parameter_association_name").alias("parameter_association_name"),
    ).distinct()
    lacz_images_by_gene = lacz_images_by_gene.groupBy(
        "gene_accession_id", "zygosity", "parameter_association_name"
    ).agg(
        collect_set("expression_image_parameter").alias("expression_image_parameters")
    )
    # Rename the association column so it joins against the categorical rows.
    lacz_images_by_gene = lacz_images_by_gene.withColumnRenamed(
        "parameter_association_name", "parameter_name"
    )
    lacz_images_by_gene = lacz_images_by_gene.withColumn(
        "parameter_name", lower("parameter_name")
    )
    lacz_observations_by_gene = lacz_observations_by_gene.join(
        lacz_images_by_gene,
        ["gene_accession_id", "zygosity", "parameter_name"],
        "left_outer",
    )
    lacz_observations_by_gene = lacz_observations_by_gene.withColumn(
        "lacZLifestage", lit(lacz_lifestage)
    )
    return lacz_observations_by_gene.distinct()
235+
236+
237+
def to_camel_case(snake_str):
    """Convert a snake_case string to camelCase.

    The first component is kept as-is; every following component is
    title-cased and concatenated (e.g. "no_expression" -> "noExpression").
    """
    head, *tail = snake_str.split("_")
    return head + "".join(part.title() for part in tail)
242+
243+
244+
def phenotype_term_zip_udf(x, y):
    """Pair an ontology term id column with its name column.

    Returns a Spark column expression that yields a struct
    ``{id: x, name: y}`` when the id column is non-null, and NULL otherwise.
    Intended for use with zip_with over parallel id/name arrays.
    """
    term_struct = struct(x.alias("id"), y.alias("name"))
    return when(x.isNotNull(), term_struct).otherwise(lit(None))
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
from impc_etl.jobs.load.impc_web_api import (
2+
ImpcConfig,
3+
PySparkTask,
4+
SparkContext,
5+
SparkSession,
6+
col,
7+
collect_set,
8+
explode_outer,
9+
luigi,
10+
phenotype_term_zip_udf,
11+
)
12+
13+
14+
class ImpcBatchQueryMapper(PySparkTask):
    """
    PySpark Task class to generate the batch query dataset: gene statistical
    results enriched with human orthologs (HGNC) and MP-to-HP ontology term
    matches.

    (The previous docstring, "parse GenTar Product report data", was a
    copy-paste leftover from another task.)
    """

    #: Name of the Spark task
    name: str = "ImpcBatchQueryMapper"

    #: Path to the TSV report mapping MGI genes to human orthologs.
    ortholog_mapping_report_tsv_path = luigi.Parameter()
    #: Path to the CSV file with MP term to HP term matches.
    mp_hp_matches_csv_path = luigi.Parameter()

    #: Path of the output directory where the new parquet file will be generated.
    output_path: luigi.Parameter = luigi.Parameter()

    def requires(self):
        # NOTE(review): ImpcGeneStatsResultsMapper is not in this module's
        # visible import list — confirm it is imported where this class is
        # defined, otherwise this raises NameError at scheduling time.
        return [ImpcGeneStatsResultsMapper()]

    def output(self):
        """
        Returns the full parquet path as an output for the Luigi Task
        (e.g. impc/dr15.2/parquet/product_report_parquet)
        """
        return ImpcConfig().get_target(
            f"{self.output_path}/impc_web_api/batch_query_data_parquet"
        )

    def app_options(self):
        """
        Generates the options pass to the PySpark job
        """
        return [
            self.ortholog_mapping_report_tsv_path,
            self.mp_hp_matches_csv_path,
            self.input()[0].path,
            self.output().path,
        ]

    def main(self, sc: SparkContext, *args):
        """
        Takes in a SparkContext and the list of arguments generated by `app_options` and executes the PySpark job.
        """
        spark = SparkSession(sc)

        # Parsing app options
        ortholog_mapping_report_tsv_path = args[0]
        mp_hp_matches_csv_path = args[1]
        gene_stats_results_json_path = args[2]
        output_path = args[3]

        ortholog_mapping_df = spark.read.csv(
            ortholog_mapping_report_tsv_path, sep="\t", header=True
        )
        stats_results = spark.read.json(gene_stats_results_json_path)

        # Normalize the ortholog report columns to the API naming scheme.
        ortholog_mapping_df = ortholog_mapping_df.select(
            col("Mgi Gene Acc Id").alias("mgiGeneAccessionId"),
            col("Human Gene Symbol").alias("humanGeneSymbol"),
            col("Hgnc Acc Id").alias("hgncGeneAccessionId"),
        ).distinct()

        # Attach ortholog info; left outer so genes without a human ortholog
        # are kept.
        stats_results = stats_results.join(
            ortholog_mapping_df, "mgiGeneAccessionId", how="left_outer"
        )

        mp_matches_df = spark.read.csv(mp_hp_matches_csv_path, header=True)
        mp_matches_df = mp_matches_df.select(
            col("curie_x").alias("id"),
            col("curie_y").alias("hp_term_id"),
            col("label_y").alias("hp_term_name"),
        ).distinct()

        stats_mp_hp_df = stats_results.select(
            "statisticalResultId",
            "potentialPhenotypes",
            "intermediatePhenotypes",
            "topLevelPhenotypes",
            "significantPhenotype",
        )
        # Explode each phenotype array into a singular column (the plural
        # name minus its trailing "s"), one row per term.
        for phenotype_list_col in [
            "potentialPhenotypes",
            "intermediatePhenotypes",
            "topLevelPhenotypes",
        ]:
            stats_mp_hp_df = stats_mp_hp_df.withColumn(
                phenotype_list_col[:-1], explode_outer(phenotype_list_col)
            )

        # A stats result matches an HP term if any of its MP terms
        # (significant, potential, intermediate or top-level) matches.
        stats_mp_hp_df = stats_mp_hp_df.join(
            mp_matches_df,
            (
                (col("significantPhenotype.id") == col("id"))
                | (col("potentialPhenotype.id") == col("id"))
                | (col("intermediatePhenotype.id") == col("id"))
                | (col("topLevelPhenotype.id") == col("id"))
            ),
            how="left_outer",
        )
        stats_mp_hp_df = stats_mp_hp_df.withColumn(
            "humanPhenotype",
            phenotype_term_zip_udf(col("hp_term_id"), col("hp_term_name")),
        )
        # Collapse back to one row per stats result with the set of matched
        # HP terms.
        stats_mp_hp_df = (
            stats_mp_hp_df.groupBy("statisticalResultId")
            .agg(collect_set("humanPhenotype").alias("humanPhenotypes"))
            .select("statisticalResultId", "humanPhenotypes")
            .distinct()
        )

        stats_results = stats_results.join(stats_mp_hp_df, "statisticalResultId")

        stats_results.write.parquet(output_path)

0 commit comments

Comments
 (0)