# Per-class RGBA palette: row i is the colour of class label i + 1
# (label 0 = empty is never drawn, see the filter below).
# Defined before drawing so the scalar colour range can be derived from it.
colors = np.array(
    [
        [100, 150, 245, 255],
        [100, 230, 245, 255],
        [30, 60, 150, 255],
        [80, 30, 180, 255],
        [100, 80, 250, 255],
        [255, 30, 30, 255],
        [255, 40, 200, 255],
        [150, 30, 90, 255],
        [255, 0, 255, 255],
        [255, 150, 255, 255],
        [75, 0, 75, 255],
        [175, 0, 75, 255],
        [255, 200, 0, 255],
        [255, 120, 50, 255],
        [0, 175, 0, 255],
        [135, 60, 0, 255],
        [150, 240, 80, 255],
        [255, 240, 150, 255],
        [255, 0, 0, 255],
    ]
).astype(np.uint8)

# Compute the coordinates of every voxel in the grid
grid_coords = get_grid_coords(
    [voxels.shape[0], voxels.shape[1], voxels.shape[2]], voxel_size
)
# Attach the predicted class label to every voxel as a 4th column
grid_coords = np.vstack([grid_coords.T, voxels.reshape(-1)]).T
# Remove empty (label 0) and unknown (label 255) voxels
occupied_voxels = grid_coords[(grid_coords[:, 3] > 0) & (grid_coords[:, 3] < 255)]
# Draw occupied voxels, coloured by their class label (4th column).
# The scalar range [vmin, vmax] is what the lookup table below is spread
# over, so it must match the labels the palette covers (1..len(colors));
# otherwise labels are mapped onto the wrong palette rows.
plt_plot = mlab.points3d(
    occupied_voxels[:, 0],
    occupied_voxels[:, 1],
    occupied_voxels[:, 2],
    occupied_voxels[:, 3],
    colormap="viridis",
    # NOTE(review): this evaluates to 0.5 * voxel_size, i.e. cubes are drawn
    # at half the voxel size; confirm that the 0.5 shrink is intended.
    scale_factor=voxel_size - 0.5 * voxel_size,
    mode="cube",
    opacity=1.0,
    vmin=1,
    vmax=len(colors),
)
plt_plot.glyph.scale_mode = "scale_by_vector"
# Replace the viridis colormap with the per-class palette
plt_plot.module_manager.scalar_lut_manager.lut.table = colors
# Save BEFORE showing: mlab.show() blocks until the window is closed, and
# after that the rendered scene is no longer available to savefig.
mlab.savefig(filename=save_path)
mlab.show()