-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathaaa.py
More file actions
144 lines (100 loc) · 6.48 KB
/
aaa.py
File metadata and controls
144 lines (100 loc) · 6.48 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import fsolve
#from jax.scipy.optimize import fsolve
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random, device_put
# Fixed PRNG key so the randomly drawn transition strengths are reproducible.
KEY = random.PRNGKey(0)
# Global variables/parameters (they will eventually have to be passed to the functions explicitly).
# Energy separating the discrete levels from the continuum.
DISCRETE_CONTINUUM_BOUNDARY = 10
# Energies of the discrete levels.
DISCRETE_ENERGIES = jnp.array([0, 3, 5, 7, 8, 9, 9.5, 10])
DISCRETE_LEVEL_NUMBER = len(DISCRETE_ENERGIES)
# Transition strengths between the discrete levels: uniform random draws in [0, 1),
# symmetrised so that entry (i, j) equals entry (j, i).
DCW = random.uniform(KEY, shape=(DISCRETE_LEVEL_NUMBER, DISCRETE_LEVEL_NUMBER))
DISCRETE_DECAY_WIDTHS = DCW + DCW.T
# The diagonal (i -> i) must be 0: mask it out of the symmetrised matrix itself.
DISCRETE_DECAY_WIDTHS = jnp.where(
    jnp.eye(DISCRETE_LEVEL_NUMBER, dtype=bool), 0.0, DISCRETE_DECAY_WIDTHS
)
# Define the continuum energy levels via a level density function
# (the original notes label this model a "Backshifted Fermi Gas").
def rho_f(energy, discrete_continuum_boundary, width=1.1):
    """Fermi-function factor of the continuum level density.

    Returns ``1 / (1 + exp((energy - discrete_continuum_boundary) / width))``.
    The value is 0.5 at the boundary, tends to 1 well below it and tends to 0
    well above it.

    Args:
        energy: Energy (scalar or array) at which to evaluate the factor.
        discrete_continuum_boundary: Energy at which the factor equals 0.5.
        width: Diffuseness of the step. Defaults to 1.1, the value previously
            hard-coded, so existing two-argument calls are unchanged.

    Returns:
        The factor, with the broadcast shape of the inputs.
    """
    return 1 / (1 + jnp.exp((energy - discrete_continuum_boundary) / width))
def rho_0(energy, disp_parameter):
    """Constant factor of the continuum level density.

    ``energy`` is accepted but ignored; the result is ``disp_parameter``
    itself, returned unchanged.
    """
    constant_factor = disp_parameter
    return constant_factor
def level_density(energy, discrete_continuum_boundary, disp_parameter):
    """Continuum level density at ``energy``.

    Combines ``rho_f`` and ``rho_0`` "in parallel" and scales by 100:
    ``100 * (1/rho_f + 1/rho_0) ** (-1)``. The smaller of the two factors
    dominates the result.
    """
    inverse_sum = (
        1 / rho_f(energy, discrete_continuum_boundary)
        + 1 / rho_0(energy, disp_parameter)
    )
    return 100 * inverse_sum ** (-1)
# Define the continuum transition strengths.
# A toy model: a damped sine squared (the transition strength should be 0 for E_gamma = 0 and smooth at 0).
def transition_strength(gamma_energy):
    """Toy transition strength as a function of gamma energy.

    Computes ``5 * sin(E)**2 * exp(-E / 10)`` for ``E >= 0`` and 0 for ``E < 0``.

    Negative inputs are clamped to 0 *before* evaluating the formula. Masking
    only afterwards with a single ``jnp.where`` would still evaluate
    ``exp(-E / 10)`` for negative ``E``. For large negative ``E`` that overflows
    to ``inf``, and the gradient through the masked-out branch becomes NaN even
    though the forward value is 0.

    Args:
        gamma_energy: Gamma energy (scalar or array).

    Returns:
        The transition strength, with the shape of ``gamma_energy``.
    """
    is_negative = gamma_energy < 0
    safe_energy = jnp.where(is_negative, 0.0, gamma_energy)
    ts = jnp.sin(safe_energy) ** 2 * 5 * jnp.exp(-safe_energy / 10)
    return jnp.where(is_negative, 0, ts)
# Differential decay width into the continuum.
def differential_decay_width(final_energy, initial_energy, discrete_continuum_boundary, disp_parameter):
    """Differential decay width from ``initial_energy`` to a continuum state at ``final_energy``.

    This is the product of the level density at the final energy and the
    transition strength at the gamma energy ``initial_energy - final_energy``.
    """
    density = level_density(final_energy, discrete_continuum_boundary, disp_parameter)
    strength = transition_strength(initial_energy - final_energy)
    return density * strength
# Numpy versions for sampling ---------------------------------------------
# CDF of the differential decay width (NUMPY VERSION)
# TODO: Consider precomputing the CDF (at least the norm) for faster sampling
def cdf_differential_decay_width(final_energy, initial_energy, discrete_continuum_boundary, disp_parameter):
    """Normalised CDF of the continuum differential decay width (NumPy version).

    Integrates ``differential_decay_width`` with the trapezoidal rule on 5000
    points from ``discrete_continuum_boundary`` up to ``final_energy``. The
    result is divided by the same integral taken up to ``initial_energy``.

    ``final_energy`` may be a scalar or a length-1 array (``fsolve`` passes the
    latter). Integration runs along axis 0 of the sample grid, so the result has
    the shape of ``final_energy``.
    """
    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid; fall back for older NumPy.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    energies = np.linspace(discrete_continuum_boundary, final_energy, 5000)
    full_energies = np.linspace(discrete_continuum_boundary, initial_energy, 5000)
    cdf_val = trapezoid(
        np.array(differential_decay_width(energies, initial_energy, discrete_continuum_boundary, disp_parameter)),
        energies,
        axis=0,
    )
    # Normalisation: total decay width into the continuum.
    cdf_norm = trapezoid(
        np.array(differential_decay_width(full_energies, initial_energy, discrete_continuum_boundary, disp_parameter)),
        full_energies,
        axis=0,
    )
    return cdf_val / cdf_norm
def decay_width_to_discrete(initial_energy, discrete_energies, discrete_continuum_boundary, disp_parameter):
    """Total decay width from ``initial_energy`` to all discrete levels (NumPy version).

    Sums ``transition_strength(initial_energy - E_i)`` over ``discrete_energies``.
    Levels above ``initial_energy`` contribute 0, because ``transition_strength``
    is 0 for negative gamma energies.

    ``discrete_continuum_boundary`` and ``disp_parameter`` are not used here.
    Presumably they are kept so the signature matches the other decay-width
    helpers.
    """
    gamma_energies = initial_energy - discrete_energies
    return np.sum(transition_strength(gamma_energies))
# Inverse CDF of the differential decay width (for sampling) (NUMPY VERSION)
def inverse_cdf_differential_decay_width(cdf_value, initial_energy, discrete_continuum_boundary, disp_parameter):
    """Return the final energy at which the continuum CDF equals ``cdf_value``.

    Solves ``cdf_differential_decay_width(E) - cdf_value = 0`` with
    ``scipy.optimize.fsolve``. The solve starts from the midpoint between
    ``discrete_continuum_boundary`` and ``initial_energy``.

    If fsolve reports that it did not converge, a ``RuntimeWarning`` is issued
    instead of silently returning its last iterate. The iterate is still
    returned.
    """
    import warnings

    def residual(final_energy):
        return cdf_differential_decay_width(
            final_energy, initial_energy, discrete_continuum_boundary, disp_parameter
        ) - cdf_value

    # Initial guess: midpoint of the continuum interval (plain NumPy, as fsolve works in NumPy).
    x0 = np.array([0.5 * (initial_energy + discrete_continuum_boundary)])
    root, _info, ier, message = fsolve(residual, x0, full_output=True)
    if ier != 1:
        warnings.warn(
            f"inverse CDF did not converge for cdf_value={cdf_value!r}: {message}",
            RuntimeWarning,
            stacklevel=2,
        )
    return root[0]
# Inverse CDF that also accounts for decays to a discrete level (NUMPY VERSION)
def spicy_inverse_cdf_differential_decay_width(cdf_value, initial_energy, discrete_energies, discrete_continuum_boundary, disp_parameter):
    """Map ``cdf_value`` to the next energy, allowing decays to the continuum or to the discrete levels.

    The unit interval is split at ``continuum_cut``, which equals
    continuum width / (continuum width + total width to discrete levels):

    * ``cdf_value <= continuum_cut``: the decay stays in the continuum.
      ``cdf_value`` is rescaled by ``continuum_cut`` and passed to
      ``inverse_cdf_differential_decay_width``.
    * otherwise: the decay goes to a discrete level. The sentinel ``-1`` is
      returned as the energy; which discrete level is not decided here.

    Returns:
        ``(root, continuum_cut)``.
    """
    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid; fall back for older NumPy.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    full_energies = np.linspace(discrete_continuum_boundary, initial_energy, 5000)
    # Total decay width into the continuum (the CDF normalisation).
    cdf_norm = trapezoid(
        np.array(differential_decay_width(full_energies, initial_energy, discrete_continuum_boundary, disp_parameter)),
        full_energies,
        axis=0,
    )
    total_decay_width_to_discrete = np.sum(transition_strength(initial_energy - discrete_energies))
    # The "continuum cut" is the CDF value that separates the continuum and discrete parts.
    # If cdf_value is drawn uniformly from [0, 1], this is the probability of staying in the continuum.
    continuum_cut = cdf_norm / (cdf_norm + total_decay_width_to_discrete)
    if cdf_value <= continuum_cut:
        root = inverse_cdf_differential_decay_width(cdf_value / continuum_cut, initial_energy, discrete_continuum_boundary, disp_parameter)
    else:
        # Sentinel: the decay goes to a discrete level.
        root = -1
    return root, continuum_cut
# -------------------------------------------------------------------------
# Jax versions for gradient computation ------------------------------------------------
def jax_cdf_differential_decay_width(cdf_norm, final_energy, initial_energy, discrete_continuum_boundary, disp_parameter):
    """Continuum CDF at ``final_energy`` (JAX version, differentiable).

    Trapezoidal integral (5000 points) of ``differential_decay_width`` from
    ``discrete_continuum_boundary`` up to ``final_energy``. The integral is
    divided by the caller-supplied ``cdf_norm``, e.g. the value of
    ``jax_cdf_norm`` for the same parameters.
    """
    # Recent JAX versions deprecate jnp.trapz in favour of jnp.trapezoid
    # (mirroring NumPy 2.0); older versions only provide trapz.
    trapezoid = getattr(jnp, "trapezoid", None) or jnp.trapz
    energies = jnp.linspace(discrete_continuum_boundary, final_energy, 5000)
    cdf_val = trapezoid(
        differential_decay_width(energies, initial_energy, discrete_continuum_boundary, disp_parameter),
        energies,
        axis=0,
    )
    return cdf_val / cdf_norm
# This norm is the total decay width to continuum states
def jax_cdf_norm(initial_energy, discrete_continuum_boundary, disp_parameter):
    """Total decay width into the continuum (JAX version, differentiable).

    Trapezoidal integral (5000 points) of ``differential_decay_width`` from
    ``discrete_continuum_boundary`` up to ``initial_energy``.
    """
    # Recent JAX versions deprecate jnp.trapz in favour of jnp.trapezoid
    # (mirroring NumPy 2.0); older versions only provide trapz.
    trapezoid = getattr(jnp, "trapezoid", None) or jnp.trapz
    full_energies = jnp.linspace(discrete_continuum_boundary, initial_energy, 5000)
    return trapezoid(
        differential_decay_width(full_energies, initial_energy, discrete_continuum_boundary, disp_parameter),
        full_energies,
        axis=0,
    )
def jax_total_decay_width_to_discrete(initial_energy, discrete_energies):
    """Sum of transition strengths from ``initial_energy`` to every discrete level (JAX version)."""
    gamma_energies = initial_energy - discrete_energies
    strengths = transition_strength(gamma_energies)
    return jnp.sum(strengths)
def jax_continuum_cut(initial_energy, discrete_energies, discrete_continuum_boundary, disp_parameter):
    """Probability that the next step from ``initial_energy`` stays in the continuum (JAX version).

    Computed as continuum width / (continuum width + total width to the
    discrete levels).
    """
    continuum_width = jax_cdf_norm(initial_energy, discrete_continuum_boundary, disp_parameter)
    discrete_width = jax_total_decay_width_to_discrete(initial_energy, discrete_energies)
    total_width = continuum_width + discrete_width
    return continuum_width / total_width
# -------------------------------------------------------------------------