1+ package com .atguigu .day8
2+
3+ import com .atguigu .day2 .SensorSource
4+ import org .apache .flink .streaming .api .scala .StreamExecutionEnvironment
5+ import org .apache .flink .api .scala ._
6+ import org .apache .flink .table .api .EnvironmentSettings
7+ import org .apache .flink .table .api .scala ._
8+ import org .apache .flink .table .functions .{AggregateFunction , TableAggregateFunction }
9+ import org .apache .flink .types .Row
10+ import org .apache .flink .util .Collector
11+
12+ object TableAggregateFunctionExample {
13+ def main (args : Array [String ]): Unit = {
14+ val env = StreamExecutionEnvironment .getExecutionEnvironment
15+ env.setParallelism(1 )
16+
17+ val stream = env.addSource(new SensorSource ).filter(_.id.equals(" sensor_1" ))
18+
19+ val settings = EnvironmentSettings
20+ .newInstance()
21+ .useBlinkPlanner()
22+ .inStreamingMode()
23+ .build()
24+
25+ val tEnv = StreamTableEnvironment .create(env, settings)
26+
27+ val top2Temp = new Top2Temp ()
28+
29+ // table api
30+ val table = tEnv.fromDataStream(stream, ' id , ' timestamp as ' ts , ' temperature )
31+
32+ table
33+ .groupBy(' id )
34+ .flatAggregate(top2Temp(' temperature ) as (' temp , ' rank ))
35+ .select(' id , ' temp , ' rank )
36+ .toRetractStream[Row ]
37+ .print()
38+
39+ // sql
40+ tEnv.createTemporaryView(" t" , table)
41+
42+ tEnv.registerFunction(" top2Temp" , top2Temp)
43+
44+ // tEnv
45+ // .sqlQuery(
46+ // """
47+ // |SELECT id, top2Temp(temperature)
48+ // | FROM t GROUP BY id""".stripMargin)
49+ // .toRetractStream[Row]
50+ // .print()
51+
52+
53+ env.execute()
54+ }
55+
56+ // 累加器
57+ class Top2TempAcc {
58+ var highestTemp : Double = Double .MinValue
59+ var secondHighestTemp : Double = Double .MinValue
60+ }
61+
62+ // 第一个泛型是输出:(温度值,排名)
63+ // 第二个泛型是累加器
64+ class Top2Temp extends TableAggregateFunction [(Double , Int ), Top2TempAcc ] {
65+ override def createAccumulator (): Top2TempAcc = new Top2TempAcc
66+
67+ // 函数名必须是accumulate
68+ def accumulate (acc : Top2TempAcc , temp : Double ): Unit = {
69+ if (temp > acc.highestTemp) {
70+ acc.secondHighestTemp = acc.highestTemp
71+ acc.highestTemp = temp
72+ } else if (temp > acc.secondHighestTemp) {
73+ acc.secondHighestTemp = temp
74+ }
75+ }
76+
77+ // 函数名必须是emitValue,用来发射计算结果
78+ def emitValue (acc : Top2TempAcc , out : Collector [(Double , Int )]): Unit = {
79+ out.collect(acc.highestTemp, 1 )
80+ out.collect(acc.secondHighestTemp, 2 )
81+ }
82+ }
83+ }
0 commit comments