diff --git a/main.py b/main.py index 0e06207..5c63faf 100644 --- a/main.py +++ b/main.py @@ -8,6 +8,8 @@ from numpy.typing import NDArray from scipy.sparse import csr_matrix from scipy.sparse.csgraph import connected_components +from scipy.optimize import curve_fit, OptimizeWarning +from mpl_toolkits.mplot3d import Axes3D from typing import Tuple cli = typer.Typer(add_completion=False) @@ -76,11 +78,73 @@ def extract_orientations(df: pd.DataFrame, flip_z: bool = False) -> NDArray[np.f orientations_flipped[:, 0], orientations_flipped[:, 1], orientations_flipped[:, 2], length=arrow_scale, color='r', linewidth=0.5, arrow_length_ratio=0.1 ) + + # Plot a surface fit for each lattice + def select_best_fit(x, y, z, n): + import warnings + with warnings.catch_warnings(): + warnings.simplefilter("error", OptimizeWarning) + try: + popt_x, cov_x = curve_fit(fit_func, (y, z), x) + x_pstd = np.sqrt((np.sum(np.power(np.sqrt(np.diag(cov_x)),2)))/n) + except OptimizeWarning as w: + x_pstd = 9999 + try: + popt_y, cov_y = curve_fit(fit_func, (x, z), y) + y_pstd = np.sqrt((np.sum(np.power(np.sqrt(np.diag(cov_y)),2)))/n) + except OptimizeWarning as w: + y_pstd = 9999 + try: + popt_z, cov_z = curve_fit(fit_func, (x, y), z) + z_pstd = np.sqrt((np.sum(np.power(np.sqrt(np.diag(cov_z)),2)))/n) + except OptimizeWarning as w: + z_pstd = 9999 + + if x_pstd < y_pstd and x_pstd < z_pstd: + return 'x', popt_x + elif y_pstd < x_pstd and y_pstd < z_pstd: + return 'y', popt_y + elif z_pstd < x_pstd and z_pstd < y_pstd: + return 'z', popt_z + else: + return '', np.full(6, np.nan) + + + for i, lattice in enumerate(unique_lattices): + good_particles_in_lattice=particles_good[particles_good['lattice']==lattice] + lattice_positions = good_particles_in_lattice[['tx', 'ty', 'tz']].to_numpy() + positions_norm = (lattice_positions - positions_min) / positions_range + n_particles_in_lattice = positions_norm.shape[0] + x, y, z = positions_norm.T[0], positions_norm.T[1], positions_norm.T[2] + best_fit, popt = select_best_fit(x, y, z, n_particles_in_lattice) + if best_fit == 'x': + y_range = np.linspace(min(y), max(y), 50) + z_range = np.linspace(min(z), max(z), 50) + Y, Z = np.meshgrid(y_range, z_range) + X = fit_func((Y, Z), *popt) + elif best_fit == 'y': + x_range = np.linspace(min(x), max(x), 50) + z_range = np.linspace(min(z), max(z), 50) + X, Z = np.meshgrid(x_range, z_range) + Y = fit_func((X, Z), *popt) + elif best_fit == 'z': + x_range = np.linspace(min(x), max(x), 50) + y_range = np.linspace(min(y), max(y), 50) + X, Y = np.meshgrid(x_range, y_range) + Z = fit_func((X, Y), *popt) + else: + print(f'Could not fit a surface for lattice {lattice}, skipping...') + continue + if i==0: + ax.plot_surface(X, Y, Z, color='red', alpha=0.2, label='surface fit') + else: + ax.plot_surface(X, Y, Z, color='red', alpha=0.2) # labels and view ax.set_xlabel('x') ax.set_ylabel('y') ax.set_zlabel('z') + ax.set_zlim(0,1) ax.legend() ax.set_title(f'{filename.stem}') ax.set_box_aspect([1, 1, 1]) # equal aspect ratio @@ -89,6 +153,11 @@ def extract_orientations(df: pd.DataFrame, flip_z: bool = False) -> NDArray[np.f plt.show() +def fit_func(xy, a, b, c, d, e, f): + x, y = xy + return a*x**2 + b*y**2 + c*x*y + d*x + e*y + f + + def flip_particles( df: pd.DataFrame, positions: NDArray[np.float64], @@ -221,7 +290,7 @@ def clean_particles( # use an empty column to store the lattice id for plotting if plot: df['lattice'] = particles_lattice_id - + if orientation_flip: flip_particles( df,