"""
Gradient descent CGH algorithm implemented for 2D/3D supervision.
Any questions about the code can be addressed to Manu Gopakumar
This code and data is released under the Creative Commons
Attribution-NonCommercial 4.0 International license (CC BY-NC.) In a nutshell:
# The license is only for non-commercial use (commercial licenses can be
obtained from Stanford).
# The material is provided as-is, with no warranties whatsoever.
# If you publish any code, data, or scientific work based on this, please
cite our work.
Technical Paper:
Full-colour 3D holographic augmented-reality displays with metasurface
waveguides
Citation:
Gopakumar, M. et al. Full-colour 3D holographic augmented-reality displays
with metasurface waveguides. Nature (2024).
"""
import torch
import torch.nn as nn
import torch.optim as optim
import utils


def compute_scaled_loss(recon_field, target_amp, roi_res, loss_fn):
    """
    Scale reconstructed field brightness by a global scale factor
    before computing the loss.

    Input
    -----
    :param recon_field: reconstructed field
    :param target_amp: target scene
    :param roi_res: resolution of region of interest to optimize
    :param loss_fn: loss function to optimize

    Output
    ------
    :return: loss computed on scaled reconstruction
    """
    recon_amp = utils.crop_image(recon_field, roi_res).abs()

    # Compute scale that minimizes MSE between recon and target
    with torch.no_grad():
        s = (recon_amp * target_amp).mean() / \
            (recon_amp ** 2).mean()
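    # Note: s is the closed-form least-squares scale: minimizing
    # E(s) = mean((s * recon_amp - target_amp) ** 2) over s gives
    # s = mean(recon_amp * target_amp) / mean(recon_amp ** 2). It is
    # computed under no_grad, so the scale acts as a constant with
    # respect to the phase gradients.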
    # Compute loss on scaled reconstruction
    return loss_fn(s * recon_amp, target_amp)


def gradient_descent(init_phase, target_amp, forward_prop=None, num_iters=1000,
                     roi_res=None, lr=0.01, mem_eff=False, *args, **kwargs):
    """
    Gradient-descent based method for phase optimization.

    Input
    -----
    :param init_phase: initial phase for gradient descent iterations
    :param target_amp: target scene
    :param forward_prop: simulated propagation model
    :param num_iters: number of optimization iterations
    :param roi_res: resolution of region of interest to optimize
    :param lr: learning rate for optimization
    :param mem_eff: option for 3D scenes to trade lower peak memory usage for
        higher computational cost per iteration

    Output
    ------
    :return: phase pattern optimized to produce the desired scene
    """
    # Initialize optimization variables and optimizer
    slm_phase = init_phase.requires_grad_(True)
    optvars = [{'params': slm_phase}]
    optimizer = optim.Adam(optvars, lr=lr)

    if roi_res is not None:
        target_amp = utils.crop_image(target_amp, roi_res)
    loss_fn = nn.functional.mse_loss

    # Iteratively update phase to improve simulated reconstruction quality
    for t in range(num_iters):
        print(f'Iter {t}/{num_iters}')
        optimizer.zero_grad()

        # Simulate output of AR system and compute loss against desired output
        if mem_eff:
            # For 3D, optionally compute the gradient contribution for each
            # depth independently to reduce peak memory usage
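            # Each per-depth backward() call accumulates into slm_phase.grad,
            # so the optimizer step below is equivalent to backpropagating
            # the sum of the per-depth losses, while only one depth plane's
            # computation graph is alive at a time.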
            for depth_idx in range(target_amp.shape[1]):
                recon_field = forward_prop(slm_phase, plane_idx=depth_idx)
                loss_val = compute_scaled_loss(
                    recon_field, target_amp[:, depth_idx:depth_idx + 1, ...],
                    roi_res, loss_fn)
                loss_val.backward(retain_graph=False)
        else:
            recon_field = forward_prop(slm_phase)
            loss_val = compute_scaled_loss(recon_field, target_amp,
                                           roi_res, loss_fn)
            loss_val.backward()

        # Update phase based on the accumulated gradients
        optimizer.step()

    return slm_phase.clone().cpu().detach()
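

# ---------------------------------------------------------------------------
# Minimal usage sketch (added for illustration; not part of the original
# release). The repository supplies the real forward_prop (the simulated AR
# system propagation model); here a hypothetical toy far-field propagator
# stands in for it so the loop can be smoke-tested. The tensor shapes, ROI,
# and random target below are illustrative assumptions only.
if __name__ == '__main__':
    import math

    def toy_forward_prop(slm_phase, plane_idx=None):
        # Hypothetical stand-in: far-field (Fourier) propagation of a
        # unit-amplitude, phase-only SLM field
        field = torch.exp(1j * slm_phase)
        return torch.fft.fftshift(torch.fft.fft2(torch.fft.ifftshift(field)))

    init_phase = 2 * math.pi * torch.rand(1, 1, 512, 512) - math.pi
    target_amp = torch.rand(1, 1, 512, 512)  # placeholder target scene
    final_phase = gradient_descent(init_phase, target_amp,
                                   forward_prop=toy_forward_prop,
                                   num_iters=10, roi_res=(448, 448))
    print(f'Optimized phase shape: {tuple(final_phase.shape)}')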