-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathjacob_list_fun_model.v
369 lines (329 loc) · 10.7 KB
/
jacob_list_fun_model.v
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
Require Import vcfloat.VCFloat.
Require Import floatlib.
Require Import Coq.Lists.List. Import ListNotations.
Set Bullet Behavior "Strict Subproofs".
Require Import fma_floating_point_model.
(** * Functional model of Jacobi iteration on list-of-lists matrices *)

(** All definitions below are parameterized by the [Nans] configuration
    that the VCFloat floating-point operations require. *)
Section WITH_NANS.
Context {NANS: Nans}.

(** A diagonal matrix, represented by the list of its diagonal entries:
    element [i] of the list is the entry at row [i], column [i]. *)
Definition diagmatrix t := list (ftype t).

(** Entrywise floating-point reciprocal [1 / d_i] of a diagonal matrix. *)
Definition invert_diagmatrix {t} (v: diagmatrix t) : diagmatrix t :=
map (BDIV (Zconst t 1)) v.

(** Diagonal matrix times vector: entry [i] of the result is [d_i * v_i].
    [map2] is [map] over [combine], so the result has the length of the
    shorter argument; callers should pass arguments of equal length. *)
Definition diagmatrix_vector_mult {t}: diagmatrix t -> vector t -> vector t :=
map2 BMULT.

(** Diagonal matrix times matrix: row [i] of [m] is multiplied entrywise
    by [d_i].  Being built from [map2], surplus entries of the longer
    argument are dropped. *)
Definition diagmatrix_matrix_mult {t} (v: diagmatrix t) (m: matrix t) : matrix t :=
map2 (fun d => map (BMULT d)) v m.

(** The diagonal of [m]: the entries at [(i, i)] for [i] ranging over
    [0 .. matrix_rows_nat m - 1].  Intended for square matrices. *)
Definition diag_of_matrix {t: type} (m: matrix t) : diagmatrix t :=
map (fun i => matrix_index m i i) (seq 0 (matrix_rows_nat m)).

(** The off-diagonal part of [m]: a matrix with [matrix_rows_nat m] rows
    and columns that equals [m] off the diagonal and is [0] on it.  Both
    dimensions are taken from the row count, so [m] is assumed square. *)
Definition remove_diag {t} (m: matrix t) : matrix t :=
matrix_by_index (matrix_rows_nat m) (matrix_rows_nat m)
(fun i j => if Nat.eq_dec i j then Zconst t 0 else matrix_index m i j).

(** The square matrix of size [length d] that has [d] on its diagonal and
    [0] everywhere else. *)
Definition matrix_of_diag {t} (d: diagmatrix t) : matrix t :=
matrix_by_index (length d) (length d)
(fun i j => if Nat.eq_dec i j then nth i d (Zconst t 0) else Zconst t 0).

(** One Jacobi step for the splitting [A = A1 + A2], where [A1] is the
    diagonal (as a [diagmatrix]) and [A2] is the off-diagonal part:
    [x_next = A1^-1 (b - A2 x)], evaluated in floating point in exactly
    this order of operations. *)
Definition jacobi_iter {t: type} (A1: diagmatrix t) (A2: matrix t) (b: vector t) (x: vector t) : vector t :=
diagmatrix_vector_mult (invert_diagmatrix A1) (vector_sub b (matrix_vector_mult A2 x)).

(** Residual used as the stopping measure:
    [A1 (jacobi_iter A1 A2 b x - x)].  In exact arithmetic this equals
    [b - (A1 + A2) x]; in floating point it is evaluated as written. *)
Definition jacobi_residual {t: type} (A1: diagmatrix t) (A2: matrix t) (b: vector t) (x: vector t) : vector t :=
diagmatrix_vector_mult A1 (vector_sub (jacobi_iter A1 A2 b x) x).

(** [going s acc] decides whether to keep iterating.  It requires [s] to
    be finite and [BCMP Lt false s acc] to hold.
    NOTE(review): under VCFloat BCMP conventions the [false] flag
    presumably negates the comparison, making this a not-less-than test
    [s >= acc]; confirm against VCFloat. *)
Definition going {t} (s acc: ftype t) :=
andb (Binary.is_finite (fprec t) (femax t) s) (BCMP Lt false s acc).

(** [iter_stop norm2 residual f n acc x] iterates [f] from [x] with a
    stopping test.  At each level it computes [s := norm2 (residual x)].
    It returns [(s, x)] when the budget [n] is [0] or when [going s acc]
    is false.  Otherwise it continues from [f x] with budget [n - 1].
    Hence the returned vector is [f] iterated at most [n] times on [x],
    [residual] is consulted at most [n + 1] times, and the returned [s]
    is always the measure of the returned vector. *)
Fixpoint iter_stop {t} {A} (norm2: A -> ftype t) (residual: A -> A) (f : A -> A) (n:nat) (acc: ftype t) (x:A) :=
(* Next iterate: bound here, but only used when iteration continues. *)
let y := f x in
let s := norm2 (residual x) in
match n with
| O => (s, x)
| S n' => if going s acc
then iter_stop norm2 residual f n' acc y
else (s,x)
end.

(** Exactly [n] Jacobi steps from [x], with no stopping test. *)
Definition jacobi_n {t: type} (A: matrix t) (b: vector t) (x: vector t) (n: nat) : vector t :=
let A1 := diag_of_matrix A in
let A2 := remove_diag A in
Nat.iter n (jacobi_iter A1 A2 b) x.

(** [norm2] of the difference [x - y].
    NOTE(review): presumably the squared Euclidean distance, if floatlib
    [norm2] is the sum of squares; confirm. *)
Definition dist2 {t: type} (x y: vector t) := norm2 (vector_sub x y).

(** Jacobi solver model.  It splits [A] into its diagonal [A1] and its
    off-diagonal part [A2], then iterates [jacobi_iter] through
    [iter_stop], using [norm2] of [jacobi_residual] as the stopping
    measure against the threshold [acc].  It returns the final measure
    together with the final iterate.  The budget given to [iter_stop] is
    [Nat.pred n], so the residual is evaluated at most [n] times when
    [n >= 1], and [n = 0] behaves like [n = 1].
    NOTE(review): presumably chosen to match an implementation loop that
    computes one residual per iteration; confirm. *)
Definition jacobi {t: type} (A: matrix t) (b: vector t) (x: vector t) (acc: ftype t) (n: nat) :
ftype t * vector t :=
let A1 := diag_of_matrix A in
let A2 := remove_diag A in
iter_stop norm2 (jacobi_residual A1 A2 b) (jacobi_iter A1 A2 b) (Nat.pred n) acc x.
End WITH_NANS.
Module Experiment.
(***************** Some sanity checks about diag_of_matrix and matrix_of_diag ***)
(* This turned out to be much lengthier than I expected. I had to develop
a whole theory of extensional matrices. It started to feel like I was
recapitulating all of MathComp. My conclusion
is that all of these lemmas should be done at the MathComp level
and not at the list-of-lists level. None of these lemmas are needed
by the VST proofs, for example. *)
Section WITH_NANS.
Context {NANS: Nans}.
(* [inv H]: invert hypothesis [H], then discard it and substitute the
   resulting equations. *)
Local Ltac inv H := inversion H; clear H; subst.

(** The diagonal of [m] has one entry per row of [m].
    NOTE(review): the squareness hypothesis is not used by the proof. *)
Lemma length_diag_of_matrix: forall {t} (m: matrix t),
matrix_cols_nat m (matrix_rows_nat m) ->
length (diag_of_matrix m) = matrix_rows_nat m.
Proof.
intros.
unfold diag_of_matrix.
rewrite length_map. rewrite length_seq. auto.
Qed.

(** [matrix_of_diag v] has [length v] rows. *)
Lemma rows_matrix_of_diag: forall {t} (v: diagmatrix t),
matrix_rows_nat (matrix_of_diag v) = length v.
Proof.
intros.
unfold matrix_of_diag.
apply matrix_by_index_rows.
Qed.

(** Every row of [matrix_of_diag v] has [length v] entries. *)
Lemma cols_matrix_of_diag: forall {t} (v: diagmatrix t),
matrix_cols_nat (matrix_of_diag v) (length v).
Proof.
intros.
unfold matrix_of_diag.
apply matrix_by_index_cols.
Qed.

(** Extracting the diagonal of [matrix_of_diag d] gives back [d]. *)
Lemma diag_of_matrix_of_diag:
forall {t} (d: diagmatrix t),
diag_of_matrix (matrix_of_diag d) = d.
Proof.
intros.
unfold diag_of_matrix, matrix_of_diag.
(* Prove the list equality pointwise with [all_nth_eq]: equal lengths,
   then equal [nth] elements. *)
apply (all_nth_eq (Zconst t 0));
rewrite length_map, length_seq, matrix_by_index_rows. auto.
intros.
set (f := fun _ => _).
(* Switch the [nth] default to [f (length d)] so that [map_nth] applies.
   The switch is justified with [nth_overflow], because index
   [length d] is out of range. *)
transitivity (nth i (map f (seq 0 (length d))) (f (length d))).
f_equal. subst f. simpl.
unfold matrix_by_index.
unfold matrix_index.
rewrite nth_overflow; auto.
rewrite nth_overflow; auto.
simpl; lia.
rewrite length_map. rewrite length_seq. lia.
rewrite map_nth.
rewrite seq_nth by auto.
simpl.
subst f. simpl.
rewrite matrix_by_index_index by auto.
destruct (Nat.eq_dec i); try contradiction; auto.
Qed.

(** If every entry of a square matrix [m] satisfies [P], then so does
    every diagonal entry. *)
Lemma Forall_diag_of_matrix {t}: forall (P: ftype t -> Prop) (m: matrix t),
matrix_cols_nat m (matrix_rows_nat m) ->
Forall (Forall P) m -> Forall P (diag_of_matrix m).
Proof.
intros.
apply Forall_map.
apply Forall_seq.
intros.
(* For each row index [i], the entry at [(i, i)] satisfies [P]: row [i]
   satisfies [Forall P], and column [i] is in range because every row
   has [length m] entries. *)
red in H.
unfold matrix_index.
unfold matrix_rows_nat in *.
rewrite Forall_nth in H0.
specialize (H0 i nil ltac:(lia)).
rewrite Forall_nth in H0.
apply (H0 i (Zconst t 0)).
rewrite Forall_forall in H.
specialize (H (nth i m nil)).
rewrite H. lia.
apply nth_In. lia.
Qed.

(** Combining two same-shape matrices entrywise with [map2 (map2 op)]
    agrees, up to [feq], with building the matrix by index from [op]
    applied to corresponding entries. *)
Lemma matrix_binop_by_index:
forall {t} (op: ftype t -> ftype t -> ftype t) (m1 m2: matrix t) cols,
matrix_rows_nat m1 = matrix_rows_nat m2 ->
matrix_cols_nat m1 cols -> matrix_cols_nat m2 cols ->
Forall2 (Forall2 feq) (map2 (map2 op) m1 m2)
(matrix_by_index (matrix_rows_nat m1) cols (fun i j => op (matrix_index m1 i j) (matrix_index m2 i j))).
Proof.
intros.
apply (matrix_extensionality _ _ cols); auto.
-
(* Row counts agree. *)
rewrite matrix_by_index_rows.
clear H0 H1.
revert m2 H; induction m1; destruct m2; simpl; intros; inv H; auto.
f_equal; eauto.
-
(* Every row of the [map2] result has [cols] entries. *)
clear H.
revert m2 H1; induction H0; destruct m2; simpl; intros; constructor.
inv H1.
unfold uncurry, map2.
rewrite length_map.
rewrite length_combine.
rewrite H4. apply Nat.min_id.
apply IHForall.
inv H1; auto.
-
(* Every row of the by-index matrix has [cols] entries. *)
apply matrix_by_index_cols.
-
(* Entries agree at every in-range index [(i, j)]. *)
intros.
assert (matrix_rows_nat (map2 (map2 op) m1 m2) = matrix_rows_nat m1). {
clear - H. revert m2 H; induction m1; destruct m2; simpl; intros; inv H; f_equal; eauto.
}
rewrite H4 in *.
rewrite matrix_by_index_index; auto.
revert m2 H H1 i H2 H4; induction m1; destruct m2; simpl; intros; inv H.
+ lia.
+ destruct i; simpl.
* clear IHm1.
(* Row 0: walk along both rows to column [j]. *)
unfold matrix_index.
unfold map2 at 1. unfold combine. simpl.
unfold map2.
inv H0. inv H1.
clear - H3 H5.
revert j H3 l H5; induction a; destruct j,l; simpl; intros; inv H5; auto.
inv H3. inv H3.
simpl in H3.
eapply IHa; eauto. lia.
* (* Later rows: peel off the first row and use the induction hypothesis.
     NOTE(review): [matrix_add] does not occur in this goal, so the
     unfold below looks vestigial. *)
unfold matrix_add.
change (map2 (map2 ?f) (a::m1) (l::m2)) with (map2 f a l :: map2 (map2 f) m1 m2).
repeat change (matrix_index (?x::?y) (S i) j) with (matrix_index y i j).
inv H1. inv H0.
eapply IHm1; eauto. lia.
Qed.

(** [remove_diag] preserves the number of rows.
    NOTE(review): the squareness hypothesis is not used by the proof. *)
Lemma matrix_rows_nat_remove_diag: forall {t} (m: matrix t),
matrix_cols_nat m (matrix_rows_nat m) ->
matrix_rows_nat m = matrix_rows_nat (remove_diag m).
Proof.
intros.
symmetry;
apply matrix_by_index_rows.
Qed.

(** Combining two matrices entrywise preserves the row count when both
    have the same number of rows. *)
Lemma matrix_rows_nat_matrix_binop:
forall {t} (op: ftype t -> ftype t -> ftype t) (m1 m2: matrix t),
matrix_rows_nat m1 = matrix_rows_nat m2 ->
matrix_rows_nat (map2 (map2 op) m1 m2) = matrix_rows_nat m1.
Proof.
intros.
unfold matrix_rows_nat in *.
unfold map2.
rewrite length_map.
rewrite length_combine.
lia.
Qed.

(** If every row of [m1] and of [m2] has [cols] entries, then so does
    every row of their entrywise combination. *)
Lemma matrix_cols_nat_matrix_binop:
forall {t} (op: ftype t -> ftype t -> ftype t) (m1 m2: matrix t) cols,
matrix_cols_nat m1 cols -> matrix_cols_nat m2 cols ->
matrix_cols_nat (map2 (map2 op) m1 m2) cols.
Proof.
induction m1; destruct m2; simpl; intros.
(* Cases where at least one input is empty: the result has no rows. *)
constructor.
constructor.
constructor.
inv H.
inv H0.
unfold map2 at 1.
unfold combine; fold (combine m1 m2).
simpl.
constructor; auto.
unfold map2. rewrite length_map.
rewrite length_combine; lia.
apply IHm1; auto.
Qed.

(** Mapping a unary operation over every entry preserves row lengths. *)
Lemma matrix_cols_nat_matrix_unop:
forall {t} (op: ftype t -> ftype t) (m: matrix t) cols,
matrix_cols_nat m cols ->
matrix_cols_nat (map (map op) m) cols.
Proof.
induction 1.
constructor.
constructor.
rewrite length_map. auto.
apply IHForall.
Qed.

(** Every row of [remove_diag m] has [matrix_rows_nat m] entries.
    NOTE(review): the squareness hypothesis is not used by the proof. *)
Lemma matrix_cols_nat_remove_diag: forall {t} (m: matrix t),
matrix_cols_nat m (matrix_rows_nat m) ->
matrix_cols_nat (remove_diag m) (matrix_rows_nat m).
Proof.
intros.
apply matrix_by_index_cols.
Qed.

(* The remaining lemmas state index bounds with [<], read in [nat] scope. *)
Local Open Scope nat.

(** For in-range indices, the entry of [matrix_of_diag d] at [(i, j)] is
    [nth i d] on the diagonal and [0] elsewhere. *)
Lemma matrix_index_diag:
forall {t} (d: diagmatrix t) i j,
i < length d -> j < length d ->
matrix_index (matrix_of_diag d) i j =
if (Nat.eq_dec i j) then nth i d (Zconst t 0) else Zconst t 0.
Proof.
intros.
unfold matrix_of_diag.
apply matrix_by_index_index; auto.
Qed.

(** For in-range indices, the entry at [(i, j)] of the entrywise
    combination [map2 (map2 f) m1 m2] is [f] applied to the entries of
    [m1] and [m2] at [(i, j)]. *)
Lemma binop_matrix_index:
forall {t} (f: ftype t -> ftype t -> ftype t)
(m1 m2: matrix t) cols,
matrix_rows_nat m1 = matrix_rows_nat m2 ->
matrix_cols_nat m1 cols -> matrix_cols_nat m2 cols ->
forall i j, i < matrix_rows_nat m1 -> j < cols ->
matrix_index (map2 (map2 f) m1 m2) i j =
f (matrix_index m1 i j) (matrix_index m2 i j).
Proof.
unfold matrix_rows_nat.
induction m1; destruct m2; simpl; intros; inv H.
simpl in H2. lia.
inv H0.
inv H1.
destruct i.
(* Row 0: walk along both rows to column [j]. *)
unfold matrix_index. simpl.
unfold map2.
clear - H3 H4.
revert j H3 l H4; induction a; destruct l,j; simpl; intros; inv H4; auto.
simpl in H3; lia. simpl in H3; lia.
simpl in H3. apply IHa; auto. lia.
(* Later rows: recurse on the remaining rows. *)
apply (IHm1 m2 (length a)); auto.
lia.
Qed.

(** Adding the diagonal part and the off-diagonal part of a square matrix
    [m] with finite entries gives back [m], up to [feq].  The finiteness
    hypothesis discharges, via [matrix_index_prop], the side conditions
    of the [BPLUS_0_r] and [BPLUS_0_l] steps. *)
Lemma remove_plus_diag: forall {t} (m: matrix t),
matrix_cols_nat m (matrix_rows_nat m) ->
Forall (Forall finite) m ->
Forall2 (Forall2 feq) (matrix_add (matrix_of_diag (diag_of_matrix m)) (remove_diag m)) m.
Proof.
intros.
apply matrix_extensionality with (cols := matrix_rows_nat m); auto.
(* Row count of the sum equals that of [m]. *)
unfold matrix_add.
rewrite matrix_rows_nat_matrix_binop.
unfold matrix_of_diag.
rewrite matrix_by_index_rows.
apply length_diag_of_matrix; auto.
unfold matrix_of_diag.
rewrite matrix_by_index_rows.
rewrite length_diag_of_matrix; auto.
unfold remove_diag.
rewrite matrix_by_index_rows; auto.
(* Every row of the sum has [matrix_rows_nat m] entries. *)
apply matrix_cols_nat_matrix_binop.
replace (matrix_rows_nat m) with (length (diag_of_matrix m)).
apply matrix_by_index_cols.
apply length_diag_of_matrix; auto.
apply matrix_by_index_cols.
(* Entrywise: at each in-range [(i, j)], the entry of the sum is [feq]
   to the entry of [m]. *)
unfold matrix_add at 1.
rewrite matrix_rows_nat_matrix_binop.
2:{ unfold matrix_of_diag. rewrite matrix_by_index_rows.
unfold remove_diag. rewrite matrix_by_index_rows.
apply length_diag_of_matrix; auto.
}
unfold matrix_of_diag at 1.
rewrite matrix_by_index_rows; auto.
rewrite length_diag_of_matrix; auto.
intros.
unfold matrix_add.
rewrite binop_matrix_index with (cols := matrix_rows_nat m); auto.
unfold matrix_of_diag, remove_diag.
rewrite !matrix_by_index_index; auto.
(* On the diagonal the entry is [m_ii + 0]; off the diagonal it is
   [0 + m_ij]. *)
destruct (Nat.eq_dec i j).
unfold diag_of_matrix.
rewrite nth_map_seq; auto.
subst j.
apply BPLUS_0_r.
eapply matrix_index_prop; eauto.
apply BPLUS_0_l.
eapply matrix_index_prop; eauto.
(* Remaining length and shape side conditions from the rewrites above. *)
rewrite length_diag_of_matrix; auto.
rewrite length_diag_of_matrix; auto.
unfold matrix_of_diag, remove_diag.
rewrite !matrix_by_index_rows; auto.
apply length_diag_of_matrix; auto.
replace (matrix_rows_nat m) with (length (diag_of_matrix m)).
apply matrix_by_index_cols.
apply length_diag_of_matrix; auto.
apply matrix_by_index_cols.
unfold matrix_of_diag.
rewrite matrix_by_index_rows; auto.
rewrite length_diag_of_matrix; auto.
Qed.
End WITH_NANS.
End Experiment.