gene_ref_parser #333

Open
wants to merge 4 commits into base: dev
@@ -1,17 +1,24 @@
"""
Allele reference database extractor module.
This module groups together tasks to parse allele data from the GenTar Reference DataBase.
This module groups together tasks to parse allele data from the GenTaR Reference DataBase.
"""
from typing import Any
import logging

import luigi
from luigi.contrib.spark import PySparkTask
from pyspark import SparkContext
from pyspark.sql import SparkSession
from pyspark.sql.functions import collect_set
from airflow.sdk import Variable, asset
from airflow.hooks.base import BaseHook

from impc_etl.workflow import SmallPySparkTask
from impc_etl.workflow.config import ImpcConfig

from impc_etl.utils.airflow import create_input_asset, create_output_asset
from impc_etl.utils.spark import with_spark_postgres_session


task_logger = logging.getLogger("airflow.task")
dr_tag = Variable.get("data_release_tag")

gene_ref_parquet_asset = create_output_asset("gene_ref_parquet")
input_asset = create_input_asset("output")

conn = BaseHook.get_connection("reference_database")

gene_ref_cols = [
"ensembl_chromosome",
@@ -41,64 +48,31 @@
"human_gene_acc_id",
]


class ExtractGeneRef(SmallPySparkTask):
@asset.multi(
schedule=[input_asset],
outlets=[gene_ref_parquet_asset],
dag_id=f"{dr_tag}_extract_gene_ref_parser",
)
@with_spark_postgres_session
def extract_gene_ref():
"""
PySparkTask task to parse allele reference data from the GenTar reference database.
Airflow asset task to parse allele reference data from the GenTaR reference database.
"""

#: Name of the Spark task
name = "IMPC_Extract_Gene_Ref"

#: Reference DB connection JDBC string
ref_db_jdbc_connection_str = luigi.Parameter()

#: Reference DB user
ref_database_user = luigi.Parameter()
from pyspark.sql import SparkSession
from pyspark.sql.functions import collect_set

spark = SparkSession.builder.getOrCreate()

#: Reference DB password
ref_database_password = luigi.Parameter()

#: Path of the output directory where the new parquet file will be generated.
output_path = luigi.Parameter()
jdbc_connection_str = f"jdbc:postgresql://{conn.host}:{conn.port}/{conn.schema}"
login_name = conn.login
login_password = conn.password
output_path = gene_ref_parquet_asset.uri

def output(self):
"""
Returns the full parquet path as an output for the Luigi Task
(e.g. impc/dr16.0/parquet/gene_ref_parquet)
"""
return ImpcConfig().get_target(f"{self.output_path}gene_ref_parquet")

def app_options(self):
"""
Generates the options pass to the PySpark job
"""
return [
self.ref_db_jdbc_connection_str,
self.ref_database_user,
self.ref_database_password,
self.output().path,
]

def main(self, sc: SparkContext, *args: Any):
"""
DCC Extractor job runner
:param list argv: the list elements should be:
"""
jdbc_connection_str = args[0]
ref_database_user = args[1]
ref_database_password = args[2]
output_path = args[3]

db_properties = {
"user": ref_database_user,
"password": ref_database_password,
"driver": "org.postgresql.Driver",
}

spark = SparkSession(sc)

sql_query = """
task_logger.info(f"Connection: {jdbc_connection_str}")

sql_query = """
SELECT
mouse_gene.*,
mouse_gene_synonym.synonym,
@@ -133,18 +107,23 @@ def main(self, sc: SparkContext, *args: Any):
ON human_gene.id = human_gene_synonym_relation.human_gene_id
LEFT JOIN public.human_gene_synonym
ON human_gene_synonym_relation.human_gene_synonym_id = human_gene_synonym.id
"""
"""

mouse_gene_df = spark.read.jdbc(
jdbc_connection_str,
table=f"(SELECT CAST(id AS BIGINT) AS numericId, * FROM ({sql_query}) AS mouse_gene_mouse_synonym) AS mouse_gene_df",
properties=db_properties,
numPartitions=10,
column="numericId",
lowerBound=0,
upperBound=100000,
)
mouse_gene_df = mouse_gene_df.groupBy(
# Partitioned JDBC read: numPartitions must be a positive number of slices
# over the partition column (numericId) between lowerBound and upperBound.
mouse_gene_df = spark.read \
.format("jdbc") \
.option("url", jdbc_connection_str) \
.option("dbtable", f"(SELECT CAST(id AS BIGINT) AS numericId, * FROM ({sql_query}) AS mouse_gene_mouse_synonym) AS mouse_gene_df") \
.option("user", login_name) \
.option("password", login_password) \
.option("driver", "org.postgresql.Driver") \
.option("numPartitions", 10) \
.option("partitionColumn", "numericId") \
.option("lowerBound", 0) \
.option("upperBound", 100000) \
.load()


mouse_gene_df = mouse_gene_df.groupBy(
[
col_name
for col_name in mouse_gene_df.columns
@@ -162,4 +141,5 @@ def main(self, sc: SparkContext, *args: Any):
collect_set("human_symbol_synonym").alias("human_symbol_synonym"),
collect_set("human_gene_acc_id").alias("human_gene_acc_id"),
)
mouse_gene_df.select(*gene_ref_cols).write.parquet(output_path)

mouse_gene_df.select(*gene_ref_cols).write.mode("ignore").parquet(output_path)
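For context on the read options above: when partitionColumn, lowerBound, upperBound, and a positive numPartitions are all set, Spark splits the JDBC read into that many parallel queries, each constrained by a WHERE range on the partition column. A minimal sketch of the same pattern, assuming a reachable PostgreSQL instance; the host, database, credentials, and table names here are illustrative placeholders, not values from this PR:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Spark issues one query per partition, e.g. the first slice roughly becomes
# "... WHERE numeric_id < 10000" for 10 partitions over the range 0..100000.
example_df = (
    spark.read.format("jdbc")
    .option("url", "jdbc:postgresql://localhost:5432/example_db")  # hypothetical
    .option("dbtable", "(SELECT CAST(id AS BIGINT) AS numeric_id, * FROM example_table) AS t")
    .option("user", "example_user")          # hypothetical
    .option("password", "example_password")  # hypothetical
    .option("driver", "org.postgresql.Driver")
    .option("numPartitions", 10)
    .option("partitionColumn", "numeric_id")
    .option("lowerBound", 0)
    .option("upperBound", 100000)
    .load()
)

The write above uses mode("ignore"), which skips the write and leaves existing data untouched if the target parquet path already exists.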
30 changes: 21 additions & 9 deletions impc_etl/utils/spark.py
@@ -6,26 +6,34 @@
task_logger = logging.getLogger("airflow.task")


def with_spark_session(func):
def with_spark_session(func, postgres_database=False):
@wraps(func)
def wrapper():
from pyspark.sql import SparkSession
from pyspark.conf import SparkConf

spark_connection = BaseHook.get_connection("spark-conn")
task_logger.info(f"Spark connection URL: {spark_connection.get_uri()}")
task_logger.info("::group::SPARK LOGS")

spark = (
SparkSession.builder.appName(func.__name__)
.master(f"spark://{spark_connection.host}:{spark_connection.port}")
.config("spark.ui.showConsoleProgress", False)
.config("spark.ui.rolling.maxRetainedFiles", "5")
.config("spark.executor.logs.rolling.strategy", "time")
.config("spark.executor.logs.rolling.time.interval", "daily")
.config(
conf = (SparkConf()
.setAppName(func.__name__)
.setMaster(f"spark://{spark_connection.host}:{spark_connection.port}")
.set("spark.ui.showConsoleProgress", False)
.set("spark.ui.rolling.maxRetainedFiles", "5")
.set("spark.executor.logs.rolling.strategy", "time")
.set("spark.executor.logs.rolling.time.interval", "daily")
.set(
"spark.driver.extraJavaOptions",
"-Dlog4j.configuration=file:/opt/airflow/log4j.properties",
)
)

if postgres_database:
# spark.jars.packages pulls the PostgreSQL JDBC driver for the driver and executors
conf = conf.set("spark.jars.packages", "org.postgresql:postgresql:42.7.7")

spark = (
SparkSession.builder.config(conf=conf)
.getOrCreate()
)
spark_logger = logging.getLogger("spark")
@@ -40,3 +48,7 @@ def wrapper():
task_logger.info("::endgroup::")

return wrapper


def with_spark_postgres_session(func):
return with_spark_session(func, postgres_database=True)
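
A minimal usage sketch for the new wrapper, with a hypothetical task function; it assumes the same Airflow "spark-conn" connection that with_spark_session reads, and mirrors the getOrCreate() pattern used in extract_gene_ref above:

from impc_etl.utils.spark import with_spark_postgres_session


@with_spark_postgres_session
def example_postgres_task():
    from pyspark.sql import SparkSession

    # getOrCreate() returns the session created by the decorator; its SparkConf
    # already carries org.postgresql:postgresql:42.7.7 on spark.jars.packages.
    spark = SparkSession.builder.getOrCreate()
    spark.range(10).show()

Because with_spark_postgres_session simply forwards to with_spark_session with postgres_database=True, existing tasks that only need a plain session keep using @with_spark_session unchanged.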