forked from Infrasys-AI/AISystem
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path06.srt
More file actions
1442 lines (1090 loc) · 28.4 KB
/
06.srt
File metadata and controls
1442 lines (1090 loc) · 28.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1
00:00:00,900 --> 00:00:04,425
字幕组:赵含霖 谢鑫鑫
2
00:00:05,075 --> 00:00:07,520
Hello 大家好,我是 ZOMI 酱
3
00:00:07,520 --> 00:00:09,938
又来到了没什么观看
4
00:00:09,938 --> 00:00:16,080
但是我依然在坚持的一节自动微分的系列的课堂当中
5
00:00:16,080 --> 00:00:21,760
这一节主要是讲反向操作符重载去实现自动微分
6
00:00:21,960 --> 00:00:26,915
那这个自动微分的方式更类似于 PyTorch 这个 AI 框架
7
00:00:26,915 --> 00:00:30,760
就是使用反向操作符重载的自动微分
8
00:00:30,760 --> 00:00:34,240
那一起来回顾一下什么叫做操作符重载
9
00:00:34,240 --> 00:00:37,304
下面这个操作符重载的这句话
10
00:00:37,304 --> 00:00:40,706
其实是我在 Wiki 或者百度上面粘过来的
11
00:00:40,706 --> 00:00:43,160
具体在哪粘我已经忘了
12
00:00:43,160 --> 00:00:43,898
简单的来说
13
00:00:43,898 --> 00:00:46,793
其实它只是利用了语言的多态性
14
00:00:46,793 --> 00:00:50,040
然后进行了一个重载
15
00:00:50,040 --> 00:00:53,420
下面这一段反倒是没什么用
16
00:00:53,420 --> 00:00:55,592
但是依旧在那放着的一句话
17
00:00:55,592 --> 00:00:58,570
就是讲操作符自动重载的微分的方式
18
00:00:58,570 --> 00:01:01,160
的一些过去的 AI 框架
19
00:01:01,160 --> 00:01:05,520
那最典型的一个代表就是经常用到的 Pytorch
20
00:01:05,520 --> 00:01:09,281
其中最重要的就是使用数据结构 Tap
21
00:01:09,281 --> 00:01:11,430
来记录整个计算流程
22
00:01:11,430 --> 00:01:13,960
也就是理解的计算图
23
00:01:13,960 --> 00:01:17,680
但是在 Pytorch 里面,它没有一个现实的计算图
24
00:01:18,000 --> 00:01:21,188
然后在反向求解梯度的时候去 replay
25
00:01:21,538 --> 00:01:23,106
去重放我的操作
26
00:01:23,106 --> 00:01:24,456
这么一种方式
27
00:01:24,456 --> 00:01:29,055
现在来简单的去回顾一下操作符重载的基本流程
28
00:01:29,055 --> 00:01:34,025
首先就是需要用语言的多态性对操作符进行重载
29
00:01:34,025 --> 00:01:36,437
定一个特殊的数据结构
30
00:01:36,437 --> 00:01:40,120
并且对每个计算进行重载的操作
31
00:01:40,795 --> 00:01:43,164
第二个就是有一个 Tap 的一个数据结构
32
00:01:43,164 --> 00:01:48,080
对数据输出和计算进行记录
33
00:01:48,080 --> 00:01:51,218
接着记录了每一次操作之后
34
00:01:51,218 --> 00:01:53,701
需要对每次操作进行遍历
35
00:01:53,701 --> 00:01:56,440
然后计算它的微分方式
36
00:01:56,440 --> 00:01:59,128
最后就是使用链式求导法则
37
00:01:59,128 --> 00:02:05,720
把刚才遍历得到的微分的结果进行累积
38
00:02:05,720 --> 00:02:09,400
这个就完成了整个操作符重载的流程了
39
00:02:10,120 --> 00:02:13,100
操作符重载其实已经多次讲到了
40
00:02:13,100 --> 00:02:17,800
它的优点就是实现起来只需要语言去提供多态的性能
41
00:02:17,800 --> 00:02:19,952
第二个就是它的应用性非常高
42
00:02:19,952 --> 00:02:24,189
操作符重载之后跟原生语言的编程方式是类似的
43
00:02:24,680 --> 00:02:25,872
所以大家都会说
44
00:02:25,872 --> 00:02:29,882
极度的模仿操作符重载的 PyTorch 的方式
45
00:02:29,882 --> 00:02:33,776
非常方便去理解和使用
46
00:02:33,776 --> 00:02:35,451
跟理解 python 代码的方式一样
47
00:02:35,451 --> 00:02:37,080
这个就是 PyTorch 的优点
48
00:02:37,400 --> 00:02:39,374
它的缺点也是非常明显
49
00:02:39,374 --> 00:02:43,474
上面用了一个 Tape 去记录大量的操作
50
00:02:44,200 --> 00:02:48,589
这个时候就需要对特殊的数据结构进行大量的读和写了
51
00:02:48,589 --> 00:02:50,044
遍历等操作了
52
00:02:50,044 --> 00:02:52,840
非常不利于高阶的微分实现
53
00:02:52,840 --> 00:02:54,790
高阶微分可能会在
54
00:02:54,790 --> 00:02:59,700
动力学、生物分子建模、物理方程模拟等
55
00:02:59,700 --> 00:03:02,920
非常常见的一些科学计算场景经常用到
56
00:03:02,920 --> 00:03:07,120
这个时候这种自动微分的方式非常不利于求解
57
00:03:07,440 --> 00:03:10,808
第二个就是类似于 While,If else 这些控制表达
58
00:03:10,808 --> 00:03:14,133
其实很难通过操作符去重载的
59
00:03:15,080 --> 00:03:17,480
下面来看看反向模式
60
00:03:18,800 --> 00:03:22,360
反向模式一般来说是比较简单好理解的
61
00:03:22,600 --> 00:03:25,166
又回到熟悉的图里面
62
00:03:25,166 --> 00:03:28,920
正向模式,假设我现在有一个 X 的输入
63
00:03:28,920 --> 00:03:33,052
然后我正向的就是每一次去计算每一个节点
64
00:03:33,052 --> 00:03:36,000
然后去计算中间变量的导数
65
00:03:36,000 --> 00:03:37,581
最后一个个计算
66
00:03:37,581 --> 00:03:41,796
然后得到最终输出的 f(x1,x2)这个输出
67
00:03:41,796 --> 00:03:43,653
对于 X 的导数
68
00:03:43,653 --> 00:03:46,120
这个就是每次正向计算的
69
00:03:46,120 --> 00:03:49,014
那反向计算就是我从最后一个
70
00:03:49,014 --> 00:03:52,480
每个中间变量关于最初的一个导数
71
00:03:52,480 --> 00:03:55,085
那从反向开始就是从后面开始
72
00:03:55,085 --> 00:03:59,160
计算每一条路径关于逆向输入的一个导数
73
00:03:59,160 --> 00:04:06,160
最后我就求得了δf 关于 x2 和δf 关于 x1 的所有的导数形式
74
00:04:06,160 --> 00:04:08,293
那在机器学习里面
75
00:04:08,293 --> 00:04:11,143
因为我的输入神经元非常的大量
76
00:04:11,143 --> 00:04:13,600
而我的输出类别有限
77
00:04:13,600 --> 00:04:14,980
在机器学习里面
78
00:04:14,980 --> 00:04:20,280
所以一般都会用到反向模式的自动微分的方式去实现
79
00:04:20,280 --> 00:04:25,720
那这个也是反向传播的一个最原始的 idea 或者数学原理
80
00:04:25,960 --> 00:04:27,673
后面了解了这一点
81
00:04:27,673 --> 00:04:30,857
看反向传播这个算法可能会更有感觉
82
00:04:31,960 --> 00:04:35,273
下面我想通过简单的几分钟的了解
83
00:04:35,273 --> 00:04:38,728
去跟着大家一起去回顾或者学习一下
84
00:04:38,728 --> 00:04:41,960
Pytorch 的 AutoDiff 是怎么去实现的
85
00:04:41,960 --> 00:04:43,729
这里面的所有的操作方式
86
00:04:43,729 --> 00:04:49,809
都是根据 Pytorch 的最核心的框架的一个原始理念
87
00:04:49,809 --> 00:04:51,809
然后去复现的
88
00:04:51,809 --> 00:04:58,273
首先需要 from typing Import List
NameTuple,Callable,Dict,Operational
89
00:04:58,273 --> 00:05:02,200
这一些简单的操作方便下面去一个加载的
90
00:05:02,200 --> 00:05:04,426
那这个 fresh_name 有什么用呢?
91
00:05:04,426 --> 00:05:04,431
fresh_name 这个函数是用来打印跟 Tape 相关的变量
92
00:05:04,431 --> 00:05:08,520
fresh_name 这个函数是用来打印跟 Tape 相关的变量
93
00:05:08,520 --> 00:05:12,153
这是我这个 f'v{name}
94
00:05:12,153 --> 00:05:15,480
这个 Name 就是记录下面的每一条 Tape
95
00:05:15,480 --> 00:05:19,736
假设我 x1 等于 V-1,x2 等于 V0
96
00:05:19,736 --> 00:05:22,073
V-1 又通过一个计算
97
00:05:22,073 --> 00:05:27,000
每一行每一次计算都有一个 Tape 去记录的
98
00:05:27,000 --> 00:05:29,629
所以我这里面通过 fresh_name
99
00:05:29,629 --> 00:05:33,160
去记录我每一次 Tape 到底是第几个
100
00:05:33,160 --> 00:05:36,277
然后_Name 第一个就是 1
101
00:05:36,277 --> 00:05:39,216
从 1 开始不断的去累积
102
00:05:39,216 --> 00:05:42,840
然后返回 V 等于多少个
103
00:05:42,840 --> 00:05:45,315
为了更加好的理解 Pytorch 里面的
104
00:05:45,315 --> 00:05:47,413
反向模式自动微分的实现
105
00:05:47,413 --> 00:05:49,481
实现的代码过程当中
106
00:05:49,481 --> 00:05:53,080
完全不依赖于 Pytorch 的 AutoGrid 的方式
107
00:05:53,080 --> 00:05:56,912
反倒是引入了一个新的类
108
00:05:56,912 --> 00:05:58,912
这个类叫做 Variable
109
00:05:58,912 --> 00:06:01,720
也就是类似于 Pytorch 里面的 Tensor
110
00:06:01,720 --> 00:06:03,337
在计算的时候
111
00:06:03,337 --> 00:06:05,976
实际上是从最后的损失函数
112
00:06:05,976 --> 00:06:09,080
或者 l 来去进行一个计算的
113
00:06:09,080 --> 00:06:12,030
程序当中每算一个张量 X 的值
114
00:06:12,030 --> 00:06:13,334
就是它的梯度的时候
115
00:06:13,334 --> 00:06:16,680
都会去计算 dl 到 dx 的一个导数
116
00:06:16,680 --> 00:06:21,028
然后反向模式就是从 dl 对 dl 自身的导数开始
117
00:06:21,028 --> 00:06:23,455
也就是 dl 对 dl 的导数等于 1
118
00:06:23,960 --> 00:06:26,079
回头看看上面这条公式
119
00:06:26,079 --> 00:06:30,280
V5 就是我的 y 对 y 对自身的导数是 1
120
00:06:30,280 --> 00:06:31,833
从这个讲式开始
121
00:06:31,833 --> 00:06:35,342
然后使用偏导数和链式法则进行传播
122
00:06:35,342 --> 00:06:37,450
也就是下面这条公式
123
00:06:37,450 --> 00:06:39,320
然后一步步的去算的
124
00:06:39,320 --> 00:06:42,878
下面代码实现可能还是比较简单
125
00:06:42,878 --> 00:06:46,440
我的 Variable 可以理解为简单的张量
126
00:06:46,440 --> 00:06:50,242
对于张量,一开始会初始化一个值叫做 Value
127
00:06:50,242 --> 00:06:53,800
通过这个值变成张量的成员变量
128
00:06:53,800 --> 00:06:56,908
然后 self.name 就是刚才的中间变量
129
00:06:56,908 --> 00:06:59,529
如果一开始没有输入 name
130
00:06:59,529 --> 00:07:01,617
它可能就直接使用 fresh_name()
131
00:07:01,617 --> 00:07:04,762
就是刚才上面的一个函数
132
00:07:04,762 --> 00:07:07,490
fresh_name(),然后不断的去累加 1
133
00:07:08,760 --> 00:07:11,586
接着下面这几个就比较有意思了
134
00:07:11,596 --> 00:07:13,315
Constant 其实是比较方便
135
00:07:13,315 --> 00:07:16,365
去打印查看的一个过程
136
00:07:16,365 --> 00:07:18,493
会通过 Constant 然后上下文
137
00:07:18,493 --> 00:07:21,216
去把当前的一个值打印出来
138
00:07:21,216 --> 00:07:23,160
还有当前的 Name 打印出来
139
00:07:23,160 --> 00:07:26,917
下面这几个就是回到一开始去实现
140
00:07:26,917 --> 00:07:30,376
或者上两节分享内容里面的一个实现
141
00:07:30,376 --> 00:07:32,356
只有一条简单的公式
142
00:07:32,356 --> 00:07:32,376
这里面有 5 个操作
143
00:07:32,376 --> 00:07:33,856
这里面有 5 个操作
144
00:07:33,856 --> 00:07:38,062
第一个就是*、+、-、sin 和 log
145
00:07:38,062 --> 00:07:42,646
一开始并没有去实现这几个函数
146
00:07:42,646 --> 00:07:47,480
而是返回了 ops_mul、ops_add、ops_sub
147
00:07:47,480 --> 00:07:49,550
在反向自动微分的时候
148
00:07:49,550 --> 00:07:51,778
其实最核心的就是一个 Tape
149
00:07:51,778 --> 00:07:55,798
用来跟踪 Variable 的所有的计算
150
00:07:55,798 --> 00:07:58,680
以便于后面用链式求导法则的
151
00:07:58,680 --> 00:08:01,426
这里面就出现了一个 Tape 的类
152
00:08:01,426 --> 00:08:03,426
Tape 的类的数就是 NameTuple
153
00:08:03,426 --> 00:08:05,400
它是一个 String
154
00:08:05,400 --> 00:08:08,914
我的输入或者我的记录的类有两个
155
00:08:08,914 --> 00:08:10,914
第一个是 Input,第二个是 Output
156
00:08:10,914 --> 00:08:14,840
那 Propagation 就是应用链式求导法则的
157
00:08:14,840 --> 00:08:17,225
告诉我的输入是什么,输出是什么
158
00:08:17,225 --> 00:08:23,285
值得注意的是这里面的输入是我的 dl 到 doutput
159
00:08:23,285 --> 00:08:27,000
输出是 dl 除以 dInput
160
00:08:27,000 --> 00:08:31,174
Tape 把所有原始计算的累积的 List 列表
161
00:08:31,174 --> 00:08:35,716
就是我要把所有的计算逆向的过程记录下来
162
00:08:35,716 --> 00:08:41,880
最终通过遍历的方式求得每一次反向的自动微分的操作
163
00:08:41,880 --> 00:08:46,480
下面有另外一个函数,叫做 reset_tape
164
00:08:46,480 --> 00:08:53,256
这个函数很简单,就是重新初始化整个 gradient_tape
165
00:08:53,256 --> 00:08:56,310
把 gradient_tape 重新初始化一遍
166
00:08:57,160 --> 00:09:02,090
下面来看看具体的每个原子操作怎么去实现
167
00:09:02,090 --> 00:09:05,325
刚才在 Variable 或者 Tensor 里面
168
00:09:05,325 --> 00:09:09,785
重载 mul、add、sub 这些原始操作的时候
169
00:09:09,785 --> 00:09:13,235
返回的是一个 ops_sub 这个原子操作
170
00:09:13,235 --> 00:09:17,160
看看现在这个原子操作具体实现了哪些功能
171
00:09:17,160 --> 00:09:20,300
正向的时候的计算比较简单
172
00:09:20,300 --> 00:09:24,624
首第一个传进来的 other 它也是一个 Variable 或者一个张量
173
00:09:24,624 --> 00:09:28,189
这里面自身其实它是一个张量
174
00:09:28,189 --> 00:09:32,716
所以两个张量相乘,需要通过 Variable 把它们包起来
175
00:09:32,716 --> 00:09:34,625
最后返回一个 X
176
00:09:34,625 --> 00:09:38,045
这个 X 正向计算的时候直接返回出去
177
00:09:38,045 --> 00:09:42,905
中间的这一坨就是为了在反向的时候去计算的
178
00:09:42,905 --> 00:09:46,710
反向的时候先不要去看反向的计算
179
00:09:46,710 --> 00:09:49,841
而是去看一下 Tape 具体做了哪些工作
180
00:09:49,841 --> 00:09:51,886
这个就是 Tape
181
00:09:51,886 --> 00:09:58,272
Tape 就是记录输入输出还有反向操作的一个闭包函数
182
00:09:58,272 --> 00:10:01,664
Tape 刚才也是重新声明了
183
00:10:01,664 --> 00:10:05,990
只是记录输入输出还有对应的操作
184
00:10:05,990 --> 00:10:08,276
对应的反向操作就是这个
185
00:10:08,276 --> 00:10:12,880
然后通过 Gradient 的 Tape 把当前的 Tape Append 进去
186
00:10:12,880 --> 00:10:17,160
就通过一个列表 List 来记录我所有的操作
187
00:10:17,160 --> 00:10:19,741
然后就会去遍历这个 Gradient 的 Tape
188
00:10:19,741 --> 00:10:23,578
去把每一次的操作逆向的求出来
189
00:10:23,578 --> 00:10:26,772
就把所有的正向的计算操作求一遍
190
00:10:26,772 --> 00:10:28,948
把反向的计算操作求一遍
191
00:10:28,948 --> 00:10:33,354
就求得了最终的 dl 对 dx1、dx2、dx3
192
00:10:33,354 --> 00:10:38,266
反向的微分的时候有一个函数叫做 propagate
193
00:10:38,266 --> 00:10:41,374
它的输入是 dl 对 doutput 的一个值
194
00:10:41,374 --> 00:10:44,828
这个反向就是我的损失函数对输出的一个导数
195
00:10:44,828 --> 00:10:47,347
那 dl 对 dx 就是我当前的一个值
196
00:10:47,347 --> 00:10:51,494
接着就是 dx 对 dself 的一个值就是 other
197
00:10:51,494 --> 00:10:54,544
dx 对 dother 的值就是我当前的数
198
00:10:54,544 --> 00:10:59,282
乘法里面可以看到根据乘数的求导法则
199
00:10:59,282 --> 00:11:00,288
就是这两个
200
00:11:00,288 --> 00:11:04,624
然后再求 dl 对 dself 还有 dl 对 dother 的一个值
201
00:11:04,624 --> 00:11:07,669
最后就把我的输出扔出来
202
00:11:07,669 --> 00:11:09,085
因为这里面有两个输出
203
00:11:09,085 --> 00:11:12,988
所以会把两个输出都同时返回出去
204
00:11:12,988 --> 00:11:17,290
那同样的,add 操作也是相同的方式去处理
205
00:11:17,290 --> 00:11:19,160
我的 sub 操作也是相同的
206
00:11:19,160 --> 00:11:23,722
那加减乘除里面可能会简单一点的,就是加和减
207
00:11:23,722 --> 00:11:25,753
加和减无论你怎么算
208
00:11:25,753 --> 00:11:30,410
它里面就是对自身的数进行求导等于 1
209
00:11:30,410 --> 00:11:32,748
如果你对另外一个数进行求导
210
00:11:32,748 --> 00:11:35,362
那就是减号,保留减号就等于-1
211
00:11:35,362 --> 00:11:39,581
sin 还有 log 这两个也是比较简单的
212
00:11:39,581 --> 00:11:42,772
log 就是 1 除以 self.value 就可以了
213
00:11:42,772 --> 00:11:45,460
然后如果你对另外一个数进行求导
214
00:11:45,460 --> 00:11:48,056
就是对 dx 乘以 dself
215
00:11:51,160 --> 00:11:54,392
在 Pytorch,TensorFlow 或者 MindSpore 里面
216
00:11:54,392 --> 00:11:57,828
如果不显式的去设置 self.autogrid
217
00:11:57,828 --> 00:12:00,104
或者实现一个自动微分的时候
218
00:12:00,104 --> 00:12:02,104
其实只是做了一个正向的计算
219
00:12:02,160 --> 00:12:05,158
在实际上需要反向计算的时候
220
00:12:05,158 --> 00:12:08,160
就需要去声明我这个函数需要进行反向
221
00:12:08,160 --> 00:12:11,160
那这里面的反向模式也是一样的
222
00:12:11,160 --> 00:12:14,242
首先通过一个函数 grad
223
00:12:14,242 --> 00:12:17,800
然后去声明我需要进行一个反向梯度的求解
224
00:12:17,800 --> 00:12:21,840
那输入有两个,第一个是 l,第二个是 results
225
00:12:21,840 --> 00:12:25,022
输出 results 它是一个 x
226
00:12:25,022 --> 00:12:26,174
代表它是一个 List
227
00:12:26,174 --> 00:12:32,217
里面就对应于需要求导的所有的函数
228
00:12:32,217 --> 00:12:34,713
l 就是我最后的 V0
229
00:12:34,713 --> 00:12:37,009
可以看到这里面的公式
230
00:12:37,009 --> 00:12:42,001
对应的是这个 results,从下往上求
231
00:12:42,001 --> 00:12:44,305
我的 l 就是 V5
232
00:12:44,305 --> 00:12:50,321
我的 results 就是 x1 和 x2,对应的 V-1 和 V0
233
00:12:52,160 --> 00:12:54,855
回到最核心的这个里面
234
00:12:54,855 --> 00:12:58,294
首先创建一个字典 dl_d
235
00:12:58,294 --> 00:13:00,954
dl_d 它是一个字典
236
00:13:00,954 --> 00:13:03,402
里面就记录了每个 dl 对 dx
237
00:13:03,402 --> 00:13:07,160
或者 d 中间的变量的所有的名字和数值
238
00:13:07,160 --> 00:13:11,196
然后最后一个 variable 等于 1
239
00:13:11,196 --> 00:13:13,577
所以把最后一个 l 的 name 拿出来
240
00:13:13,577 --> 00:13:15,566
然后丢给它作为 1
241
00:13:15,566 --> 00:13:19,406
可以看到这里面最后一个假设是 1
242
00:13:19,406 --> 00:13:21,093
这个是所有的前提
243
00:13:22,821 --> 00:13:24,869
gather_grad 这个内联函数呢
244
00:13:24,869 --> 00:13:27,058
主要是去把所有的 entry
245
00:13:27,058 --> 00:13:30,946
就是我的 grad 里面的所有的 Tape 的数值都记录下来
246
00:13:30,946 --> 00:13:32,794
丢给我的 dl_d
247
00:13:32,794 --> 00:13:37,978
也就是把所有的数值或者我的计算的过程放在我的 dl_d 里面
248
00:13:37,978 --> 00:13:40,970
为的就是方便我进行打印的时候操作
249
00:13:41,160 --> 00:13:44,487
这个时候可以看到 dl_d
250
00:13:44,487 --> 00:13:48,842
主要是去记录所有的 dl,d0
251
00:13:48,842 --> 00:13:51,160
这个具体的计算公式
252
00:13:51,160 --> 00:13:53,727