Skip to content

Commit 9dbeed3

Browse files
update
1 parent 300f4c8 commit 9dbeed3

File tree

2 files changed

+162
-0
lines changed

2 files changed

+162
-0
lines changed
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
package com.atguigu.day8
2+
3+
import com.atguigu.day2.SensorSource
4+
import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
5+
import org.apache.flink.api.scala._
6+
import org.apache.flink.table.api.EnvironmentSettings
7+
import org.apache.flink.table.api.scala._
8+
import org.apache.flink.table.functions.AggregateFunction
9+
import org.apache.flink.types.Row
10+
11+
object AggregateFunctionExample {
12+
def main(args: Array[String]): Unit = {
13+
val env = StreamExecutionEnvironment.getExecutionEnvironment
14+
env.setParallelism(1)
15+
16+
val stream = env.addSource(new SensorSource).filter(_.id.equals("sensor_1"))
17+
18+
val settings = EnvironmentSettings
19+
.newInstance()
20+
.useBlinkPlanner()
21+
.inStreamingMode()
22+
.build()
23+
24+
val tEnv = StreamTableEnvironment.create(env, settings)
25+
26+
val avgTemp = new AvgTemp()
27+
28+
// table api
29+
val table = tEnv.fromDataStream(stream, 'id, 'timestamp as 'ts, 'temperature as 'temp)
30+
31+
table
32+
.groupBy('id)
33+
.aggregate(avgTemp('temp) as 'avgTemp)
34+
.select('id, 'avgTemp)
35+
.toRetractStream[Row]
36+
// .print()
37+
38+
// sql 写法
39+
tEnv.createTemporaryView("sensor", table)
40+
41+
tEnv.registerFunction("avgTemp", avgTemp)
42+
43+
tEnv.sqlQuery(
44+
"""
45+
|SELECT
46+
| id, avgTemp(temp)
47+
| FROM
48+
| sensor
49+
| GROUP BY id""".stripMargin
50+
)
51+
.toRetractStream[Row].print()
52+
53+
54+
env.execute()
55+
}
56+
57+
// 累加器的类型
58+
class AvgTempAcc {
59+
var sum: Double = 0.0
60+
var count: Int = 0
61+
}
62+
63+
// 第一个泛型是温度值的类型
64+
// 第二个泛型是累加器的类型
65+
class AvgTemp extends AggregateFunction[Double, AvgTempAcc] {
66+
// 创建累加器
67+
override def createAccumulator(): AvgTempAcc = new AvgTempAcc
68+
69+
// 累加规则
70+
def accumulate(acc: AvgTempAcc, temp: Double): Unit = {
71+
acc.sum += temp
72+
acc.count += 1
73+
}
74+
75+
override def getValue(accumulator: AvgTempAcc): Double = {
76+
accumulator.sum / accumulator.count
77+
}
78+
}
79+
}
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
package com.atguigu.day8
2+
3+
import com.atguigu.day2.SensorSource
4+
import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
5+
import org.apache.flink.api.scala._
6+
import org.apache.flink.table.api.EnvironmentSettings
7+
import org.apache.flink.table.api.scala._
8+
import org.apache.flink.table.functions.{AggregateFunction, TableAggregateFunction}
9+
import org.apache.flink.types.Row
10+
import org.apache.flink.util.Collector
11+
12+
object TableAggregateFunctionExample {
13+
def main(args: Array[String]): Unit = {
14+
val env = StreamExecutionEnvironment.getExecutionEnvironment
15+
env.setParallelism(1)
16+
17+
val stream = env.addSource(new SensorSource).filter(_.id.equals("sensor_1"))
18+
19+
val settings = EnvironmentSettings
20+
.newInstance()
21+
.useBlinkPlanner()
22+
.inStreamingMode()
23+
.build()
24+
25+
val tEnv = StreamTableEnvironment.create(env, settings)
26+
27+
val top2Temp = new Top2Temp()
28+
29+
// table api
30+
val table = tEnv.fromDataStream(stream, 'id, 'timestamp as 'ts, 'temperature)
31+
32+
table
33+
.groupBy('id)
34+
.flatAggregate(top2Temp('temperature) as ('temp, 'rank))
35+
.select('id, 'temp, 'rank)
36+
.toRetractStream[Row]
37+
.print()
38+
39+
// sql
40+
tEnv.createTemporaryView("t", table)
41+
42+
tEnv.registerFunction("top2Temp", top2Temp)
43+
44+
// tEnv
45+
// .sqlQuery(
46+
// """
47+
// |SELECT id, top2Temp(temperature)
48+
// | FROM t GROUP BY id""".stripMargin)
49+
// .toRetractStream[Row]
50+
// .print()
51+
52+
53+
env.execute()
54+
}
55+
56+
// 累加器
57+
class Top2TempAcc {
58+
var highestTemp: Double = Double.MinValue
59+
var secondHighestTemp: Double = Double.MinValue
60+
}
61+
62+
// 第一个泛型是输出:(温度值,排名)
63+
// 第二个泛型是累加器
64+
class Top2Temp extends TableAggregateFunction[(Double, Int), Top2TempAcc] {
65+
override def createAccumulator(): Top2TempAcc = new Top2TempAcc
66+
67+
// 函数名必须是accumulate
68+
def accumulate(acc: Top2TempAcc, temp: Double): Unit = {
69+
if (temp > acc.highestTemp) {
70+
acc.secondHighestTemp = acc.highestTemp
71+
acc.highestTemp = temp
72+
} else if (temp > acc.secondHighestTemp) {
73+
acc.secondHighestTemp = temp
74+
}
75+
}
76+
77+
// 函数名必须是emitValue,用来发射计算结果
78+
def emitValue(acc: Top2TempAcc, out: Collector[(Double, Int)]): Unit = {
79+
out.collect(acc.highestTemp, 1)
80+
out.collect(acc.secondHighestTemp, 2)
81+
}
82+
}
83+
}

0 commit comments

Comments
 (0)