Skip to content

Commit 98d0858

Browse files
add mesh texture
1 parent a01227e commit 98d0858

File tree

1 file changed

+26
-10
lines changed

1 file changed

+26
-10
lines changed

extract_mesh.py

+26-10
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,27 @@
1414
from utils.tetmesh import marching_tetrahedra
1515

1616
@torch.no_grad()
17-
def evaluage_alpha(points, views, gaussians, pipeline, background, kernel_size):
17+
def evaluage_alpha(points, views, gaussians, pipeline, background, kernel_size, return_color=False):
1818
final_alpha = torch.ones((points.shape[0]), dtype=torch.float32, device="cuda")
19+
if return_color:
20+
final_color = torch.ones((points.shape[0], 3), dtype=torch.float32, device="cuda")
1921

2022
with torch.no_grad():
2123
for _, view in enumerate(tqdm(views, desc="Rendering progress")):
2224
ret = integrate(points, view, gaussians, pipeline, background, kernel_size=kernel_size)
2325
alpha_integrated = ret["alpha_integrated"]
26+
if return_color:
27+
color_integrated = ret["color_integrated"]
28+
final_color = torch.where((alpha_integrated < final_alpha).reshape(-1, 1), color_integrated, final_color)
2429
final_alpha = torch.min(final_alpha, alpha_integrated)
30+
2531
alpha = 1 - final_alpha
32+
if return_color:
33+
return alpha, final_color
2634
return alpha
2735

2836
@torch.no_grad()
29-
def marching_tetrahedra_with_binary_search(model_path, name, iteration, views, gaussians, pipeline, background, kernel_size):
37+
def marching_tetrahedra_with_binary_search(model_path, name, iteration, views, gaussians, pipeline, background, kernel_size, filter_mesh : bool, texture_mesh : bool):
3038
render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "fusion")
3139

3240
makedirs(render_path, exist_ok=True)
@@ -95,13 +103,19 @@ def alpha_to_sdf(alpha):
95103
if step not in [7]:
96104
continue
97105

98-
mesh = trimesh.Trimesh(vertices=points.cpu().numpy(), faces=faces, process=False)
106+
if texture_mesh:
107+
_, color = evaluage_alpha(points, views, gaussians, pipeline, background, kernel_size, return_color=True)
108+
vertex_colors=(color.cpu().numpy() * 255).astype(np.uint8)
109+
else:
110+
vertex_colors=None
111+
mesh = trimesh.Trimesh(vertices=points.cpu().numpy(), faces=faces, vertex_colors=vertex_colors, process=False)
99112

100113
# filter
101-
mask = (distance <= scale).cpu().numpy()
102-
face_mask = mask[faces].all(axis=1)
103-
mesh.update_vertices(mask)
104-
mesh.update_faces(face_mask)
114+
if filter_mesh:
115+
mask = (distance <= scale).cpu().numpy()
116+
face_mask = mask[faces].all(axis=1)
117+
mesh.update_vertices(mask)
118+
mesh.update_faces(face_mask)
105119

106120
mesh.export(os.path.join(render_path, f"mesh_binary_search_{step}.ply"))
107121

@@ -112,7 +126,7 @@ def alpha_to_sdf(alpha):
112126
# mesh.export(os.path.join(render_path, f"mesh_binary_search_interp.ply"))
113127

114128

115-
def extract_mesh(dataset : ModelParams, iteration : int, pipeline : PipelineParams):
129+
def extract_mesh(dataset : ModelParams, iteration : int, pipeline : PipelineParams, filter_mesh : bool, texture_mesh : bool):
116130
with torch.no_grad():
117131
gaussians = GaussianModel(dataset.sh_degree)
118132
scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False)
@@ -124,7 +138,7 @@ def extract_mesh(dataset : ModelParams, iteration : int, pipeline : PipelinePara
124138
kernel_size = dataset.kernel_size
125139

126140
cams = scene.getTrainCameras()
127-
marching_tetrahedra_with_binary_search(dataset.model_path, "test", iteration, cams, gaussians, pipeline, background, kernel_size)
141+
marching_tetrahedra_with_binary_search(dataset.model_path, "test", iteration, cams, gaussians, pipeline, background, kernel_size, filter_mesh, texture_mesh)
128142

129143
if __name__ == "__main__":
130144
# Set up command line argument parser
@@ -133,6 +147,8 @@ def extract_mesh(dataset : ModelParams, iteration : int, pipeline : PipelinePara
133147
pipeline = PipelineParams(parser)
134148
parser.add_argument("--iteration", default=30000, type=int)
135149
parser.add_argument("--quiet", action="store_true")
150+
parser.add_argument("--filter_mesh", action="store_true")
151+
parser.add_argument("--texture_mesh", action="store_true")
136152
args = get_combined_args(parser)
137153
print("Rendering " + args.model_path)
138154

@@ -141,4 +157,4 @@ def extract_mesh(dataset : ModelParams, iteration : int, pipeline : PipelinePara
141157
torch.manual_seed(0)
142158
torch.cuda.set_device(torch.device("cuda:0"))
143159

144-
extract_mesh(model.extract(args), args.iteration, pipeline.extract(args))
160+
extract_mesh(model.extract(args), args.iteration, pipeline.extract(args), args.filter_mesh, args.texture_mesh)

0 commit comments

Comments
 (0)