'''
spark/bin/spark-submit --master local --driver-memory 4g --num-executors 2 --executor-memory 4g --packages org.apache.spark:spark-sql-kafka-0-10_2.12:3.0.0 sstreaming-spark-out.py
'''
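# Assumes a local Kafka broker on localhost:9092 with the topics
# 'taxirides' and 'taxifares' already being produced to (see the
# readStream subscriptions below), and a 'nbhd.jsonl' file with
# neighborhood polygons in the working directory.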
from pyspark.sql import SparkSession
from pyspark.sql.types import (StructType, StructField, LongType, StringType,
                               TimestampType, FloatType, ShortType)
from pyspark.sql.functions import avg, expr, split, udf, window
taxiFaresSchema = StructType([
    StructField("rideId", LongType()),
    StructField("taxiId", LongType()),
    StructField("driverId", LongType()),
    StructField("startTime", TimestampType()),
    StructField("paymentType", StringType()),
    StructField("tip", FloatType()),
    StructField("tolls", FloatType()),
    StructField("totalFare", FloatType())])

taxiRidesSchema = StructType([
    StructField("rideId", LongType()),
    StructField("isStart", StringType()),
    StructField("endTime", TimestampType()),
    StructField("startTime", TimestampType()),
    StructField("startLon", FloatType()),
    StructField("startLat", FloatType()),
    StructField("endLon", FloatType()),
    StructField("endLat", FloatType()),
    StructField("passengerCnt", ShortType()),
    StructField("taxiId", LongType()),
    StructField("driverId", LongType())])

def parse_data_from_kafka_message(sdf, schema):
    """Split the comma-separated Kafka message value into typed columns,
    named and cast according to the given schema."""
    assert sdf.isStreaming, "DataFrame doesn't receive streaming data"
    # split the CSV attributes into a nested array in one column
    col = split(sdf['value'], ',')
    # expand the array into one top-level column per schema field
    for idx, field in enumerate(schema):
        sdf = sdf.withColumn(field.name, col.getItem(idx).cast(field.dataType))
    return sdf.select([field.name for field in schema])
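
# Illustration (a hypothetical record, not taken from the source data):
# a ride message value such as
#   "52693,END,2020-01-01 00:12:00,2020-01-01 00:00:00,-73.99,40.73,-73.98,40.75,1,2013000185,2013000185"
# is split on ',' and cast field by field against taxiRidesSchema, yielding
# typed columns rideId=52693, isStart='END', endTime, startTime, and so on.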

spark = SparkSession.builder \
    .appName("Spark Structured Streaming from Kafka") \
    .getOrCreate()

# Subscribe to the Kafka topics
sdfRides = spark \
    .readStream \
    .format("kafka") \
    .option("kafka.bootstrap.servers", "localhost:9092") \
    .option("subscribe", "taxirides") \
    .option("startingOffsets", "latest") \
    .load() \
    .selectExpr("CAST(value AS STRING)")

sdfFares = spark \
    .readStream \
    .format("kafka") \
    .option("kafka.bootstrap.servers", "localhost:9092") \
    .option("subscribe", "taxifares") \
    .option("startingOffsets", "latest") \
    .load() \
    .selectExpr("CAST(value AS STRING)")
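
# Each record from the Kafka source carries the fixed columns (key, value,
# topic, partition, offset, timestamp, timestampType); the CAST above keeps
# only the payload as a string for CSV parsing.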
sdfRides = parse_data_from_kafka_message(sdfRides, taxiRidesSchema)
sdfFares = parse_data_from_kafka_message(sdfFares, taxiFaresSchema)

def rides_data_cleaning(ridesSdf):
    # remove all ride events that either started or ended outside NYC
    LON_EAST, LON_WEST, LAT_NORTH, LAT_SOUTH = -73.7, -74.05, 41.0, 40.5
    ridesSdf = ridesSdf.filter(
        ridesSdf["startLon"].between(LON_WEST, LON_EAST) &
        ridesSdf["startLat"].between(LAT_SOUTH, LAT_NORTH) &
        ridesSdf["endLon"].between(LON_WEST, LON_EAST) &
        ridesSdf["endLat"].between(LAT_SOUTH, LAT_NORTH))
    # keep only finished ride events
    ridesSdf = ridesSdf.filter(ridesSdf["isStart"] == "END")
    return ridesSdf
sdfRides = rides_data_cleaning(sdfRides)
# A watermark defines how far a timestamp may lag behind the maximum event
# time seen so far; watermarks let Spark bound the state it keeps for
# stream-stream joins.
# Apply watermarks on the event-time columns
sdfFaresWithWatermark = sdfFares \
    .selectExpr("rideId AS rideId_fares", "startTime", "totalFare", "tip") \
    .withWatermark("startTime", "30 minutes")  # maximal delay
# A fares event is kept for up to 30 minutes to match it with a ride event
sdfRidesWithWatermark = sdfRides \
    .selectExpr("rideId", "endTime", "driverId", "taxiId",
                "startLon", "startLat", "endLon", "endLat") \
    .withWatermark("endTime", "30 minutes")  # maximal delay
# Join with event-time constraints
sdf = sdfFaresWithWatermark \
    .join(sdfRidesWithWatermark,
          expr("""
           rideId_fares = rideId AND
           endTime > startTime AND
           endTime <= startTime + interval 2 hours
           """))
nbhds_df = spark.read.json("nbhd.jsonl")  # one JSON object per line
# collect the neighborhood name -> polygon mapping to the driver ...
lookupdict = nbhds_df.select("name", "coord").rdd.collectAsMap()
# ... and broadcast it to the executors; use broadcastVar.value from now on
broadcastVar = spark.sparkContext.broadcast(lookupdict)
# Approximate Manhattan bounding box (a closed polygon of lon/lat points)
manhattan_bbox = [[-74.0489866963, 40.681530375],
                  [-73.8265135518, 40.681530375],
                  [-73.8265135518, 40.9548628598],
                  [-74.0489866963, 40.9548628598],
                  [-74.0489866963, 40.681530375]]

def isPointInPath(x, y, poly):
    """Even-odd ray-casting test: return True if point (x, y) lies inside
    poly, a list of [x, y] vertex pairs."""
    num = len(poly)
    j = num - 1
    c = False
    for i in range(num):
        if ((poly[i][1] > y) != (poly[j][1] > y)) and \
                (x < poly[i][0] + (poly[j][0] - poly[i][0]) *
                 (y - poly[i][1]) / (poly[j][1] - poly[i][1])):
            c = not c
        j = i
    return c
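
# Quick sanity check (a minimal sketch; the unit square below is an
# illustrative assumption, not part of the pipeline):
#   unit_square = [[0, 0], [1, 0], [1, 1], [0, 1]]
#   isPointInPath(0.5, 0.5, unit_square)  # True
#   isPointInPath(1.5, 0.5, unit_square)  # False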

def find_nbhd(lon, lat):
    """Take a geo point as lon, lat floats and return the name of the
    neighborhood it belongs to; requires broadcastVar to be available."""
    if not isPointInPath(lon, lat, manhattan_bbox):
        return "Other"
    for name, coord in broadcastVar.value.items():
        if isPointInPath(lon, lat, coord):
            return str(name)
    return "Other"  # point is inside the bbox but in no known neighborhood
find_nbhd_udf = udf(find_nbhd, StringType())
sdf = sdf.withColumn("stopNbhd", find_nbhd_udf("endLon", "endLat"))
sdf = sdf.withColumn("startNbhd", find_nbhd_udf("startLon", "startLat"))
# Debug query (commented out): count rides per driver and print to console.
# Note: a plain aggregation like this needs "complete" output mode.
# query = sdf.groupBy("driverId").count()
# query.writeStream \
#     .outputMode("complete") \
#     .format("console") \
#     .option("truncate", False) \
#     .start() \
#     .awaitTermination()
# Average tip per stop neighborhood over a sliding 30-minute window that
# advances every 10 minutes
tips = sdf \
    .groupBy(
        window("endTime", "30 minutes", "10 minutes"),
        "stopNbhd") \
    .agg(avg("tip"))

tips.writeStream \
    .outputMode("append") \
    .format("console") \
    .queryName("tips") \
    .option("truncate", False) \
    .start() \
    .awaitTermination()
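
# An alternative sink for interactive testing (a sketch, commented out):
# the memory sink exposes results as an in-memory table that can be
# queried with spark.sql while the stream runs.
# tips.writeStream \
#     .outputMode("append") \
#     .format("memory") \
#     .queryName("tips_mem") \
#     .start()
# spark.sql("SELECT * FROM tips_mem").show(truncate=False)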