# Watershed on a mesh

The notebook uses [trimesh](https://trimsh.org) and [higra](https://higra.readthedocs.io/en/stable/). Trimesh is a pure-Python library for triangular meshes. It is slow on large models and, under Colab, does not have enough memory for larger ones. Please have a look at the Higra example that uses the igl library to work around these issues.

We are going to read a mesh file, compute a curvature measure on the mesh, and compute a watershed on the dual graph of the mesh.
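The dual graph has one vertex per triangle and one edge per pair of triangles that share a mesh edge. The notebook computes it with `trimesh.graph.face_adjacency`; the following is only a minimal NumPy sketch of the same idea (not trimesh's implementation), assuming `faces` is an integer array of shape `(n_faces, 3)`. The function name `dual_graph_edges` is hypothetical.

```python
import numpy as np


def dual_graph_edges(faces):
    """Return the dual graph of a triangle mesh as pairs of face indices.

    Two faces are adjacent when they share a mesh edge. Also returns, for each
    pair, the shared mesh edge as a sorted pair of vertex indices.
    """
    faces = np.asarray(faces)
    # The three edges of every triangle, each stored as a sorted vertex pair
    e = np.concatenate([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [2, 0]]])
    e = np.sort(e, axis=1)
    # Face that owns each of those edges (same block order as the concatenation)
    owner = np.tile(np.arange(len(faces)), 3)
    # Sort edges lexicographically so identical vertex pairs become neighbours
    order = np.lexsort((e[:, 1], e[:, 0]))
    e, owner = e[order], owner[order]
    # In a closed 2-manifold every edge appears exactly twice, in consecutive rows
    dup = np.nonzero(np.all(e[1:] == e[:-1], axis=1))[0]
    adjacency = np.stack([owner[dup], owner[dup + 1]], axis=1)
    return adjacency, e[dup]
```

On a tetrahedron this yields 6 dual edges: every face is adjacent to the 3 others, so the dual of a tetrahedron is again a tetrahedron.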

Reference paper:
> Jean Cousty, Gilles Bertrand, Michel Couprie, Laurent Najman.
> Collapses and watersheds in pseudomanifolds of arbitrary dimension.
> Journal of Mathematical Imaging and Vision, volume 50, pages 261–285 (2014). [10.1007/s10851-014-0498-z](https://doi.org/10.1007/s10851-014-0498-z). [hal-00871498v2](https://hal.science/hal-00871498/)

For an application of the hierarchical watershed on a mesh, see:

> Sylvie Philipp-Foliguet, Michel M. Jordan, Laurent Najman, Jean Cousty. 
> Artwork 3D model database indexing and classification. 
> Pattern Recognition, 2011, 44 (3), pp. 588-597. [10.1016/j.patcog.2010.09.016](https://doi.org/10.1016/j.patcog.2010.09.016). [hal-00538470](https://hal.science/hal-00538470/)

In [1]:
# Trimesh is a simple python package for dealing with meshes.
!pip install "trimesh[easy]"
# WARNING: trimesh is slow for large meshes
# An alternative is igl https://libigl.github.io/libigl-python-bindings/
# Also, plotly, the viewer we are using in this notebook, is not designed for large meshes
# An alternative is meshplot https://skoch9.github.io/meshplot/
# Another Higra example uses igl and meshplot, please have a look

# Higra is for the watershed and related operators on graphs
!pip install higra


In [2]:
import trimesh
# We need a curvature operator for weighting the mesh
from trimesh.curvature import discrete_gaussian_curvature_measure, discrete_mean_curvature_measure, sphere_ball_intersection


As a first example, we work with a coarse version of the bunny model from the Stanford repository. A copy of this model in OBJ format, preprocessed with [MeshFix](https://github.com/MarcoAttene/MeshFix-V2.1), is available at the following address:

In [3]:
# This coarse model comes from pyvista; it works fine, but its resolution is low
mesh = trimesh.load_remote('https://raw.githubusercontent.com/higra/Higra-Notebooks/master/data/BunnyCoarse.obj', force='mesh')


In [4]:
assert mesh.is_watertight  # This notebook only processes watertight (closed) meshes for the watershed
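What does "watertight" require? A closed triangle mesh uses every undirected edge in exactly two triangles, so that every face has a neighbour across each of its edges in the dual graph. To our understanding this edge count is essentially what trimesh's `is_watertight` checks (winding consistency is reported separately, by `is_winding_consistent`). The sketch below checks only the edge count; `every_edge_shared_by_two_faces` is a hypothetical helper, not a trimesh function.

```python
import numpy as np


def every_edge_shared_by_two_faces(faces):
    """Edge-count test for a closed triangle mesh: True when every undirected
    edge is used by exactly two triangles."""
    faces = np.asarray(faces)
    # All triangle edges, as sorted vertex pairs so (i, j) and (j, i) coincide
    e = np.concatenate([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [2, 0]]])
    e = np.sort(e, axis=1)
    _, counts = np.unique(e, axis=0, return_counts=True)
    return bool(np.all(counts == 2))
```

A tetrahedron passes the test; removing any one of its faces leaves three boundary edges used only once, and the test fails.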

The mesh viewer in trimesh is less powerful than plotly, so we use plotly instead. Here are some helper functions.

In [5]:
# Adapted from https://plotly.com/python/v3/surface-triangulation/
import matplotlib.cm as cm
from functools import reduce
import numpy as np
import plotly.graph_objects as go
import plotly

def map_z2color(zval, colormap, vmin, vmax):
    # Map zval, normalized to [0, 1] using [vmin, vmax], to an 'rgb(r,g,b)' string from the colormap
    if vmin > vmax:
        raise ValueError('incorrect relation between vmin and vmax')
    # Guard against a constant field (vmin == vmax), which would divide by zero
    t = (zval - vmin) / float(vmax - vmin) if vmax > vmin else 0.5
    R, G, B, alpha = colormap(t)
    return 'rgb({:d},{:d},{:d})'.format(int(R*255+0.5), int(G*255+0.5), int(B*255+0.5))

# Plot a mesh, with or without the edges of the triangles.
# If colors is None, the faces are colored by height (mean z of their vertices).
def plotly_trisurf(mesh, colors=None, colormap=cm.RdBu, plot_edges=False):
    # colors, if given, holds one scalar value per vertex of the mesh

    # x, y, z are the coordinates of the mesh vertices;
    # simplices is a numpy array of shape (n_triangles, 3) holding vertex indices
    x, y, z = mesh.vertices.T
    simplices = mesh.faces
    points3D=np.vstack((x,y,z)).T
    tri_vertices=list(map(lambda index: points3D[index], simplices))# vertices of the surface triangles 

    if colors is None:
        # Mean z-coordinate of the vertices of each triangle
        zmean=[np.mean(tri[:,2]) for tri in tri_vertices]
        min_zmean=np.min(zmean)
        max_zmean=np.max(zmean)
        facecolor=[map_z2color(zz,  colormap, min_zmean, max_zmean) for zz in zmean]
    else:
        zmean=[(colors[i]+colors[j]+colors[k])/3. for i,j,k in simplices] # Mean of the color of the vertices
        min_zmean=np.min(zmean)
        max_zmean=np.max(zmean)
        facecolor=[map_z2color(zz,  colormap, min_zmean, max_zmean) for zz in zmean]
    
    I,J,K = mesh.faces.T

    triangles=go.Mesh3d(x=x,
                     y=y,
                     z=z,
                     facecolor=facecolor,
                     i=I,
                     j=J,
                     k=K,
                     name=''
                    )

    if plot_edges is False:# the triangle sides are not plotted 
        return [triangles]
    else:
        #define the lists Xe, Ye, Ze, of x, y, resp z coordinates of edge end points for each triangle
        #None separates data corresponding to two consecutive triangles
        lists_coord=[[[T[k%3][c] for k in range(4)]+[ None]   for T in tri_vertices]  for c in range(3)]
        Xe, Ye, Ze=[reduce(lambda x,y: x+y, lists_coord[k]) for k in range(3)]

        #define the lines to be plotted
        lines=go.Scatter3d(x=Xe,
                        y=Ye,
                        z=Ze,
                        mode='lines',
                        line=dict(color= 'rgb(50,50,50)', width=1.5)
               )
        return [triangles, lines]

# Plot a saliency map: mesh edges colored by their saliency, one trace per saliency level
def plotly_trisurf_saliency(mesh, edges, saliency):

    # Remove saliency edges with 0 weight
    edges = edges[saliency>0]
    saliency = saliency[saliency>0]

    #x, y, z are lists of coordinates of the triangle vertices 
    #simplices are the simplices that define the triangularization;
    #simplices  is a numpy array of shape (no_triangles, 3)
    x, y, z = mesh.vertices.T
    simplices = mesh.faces
    points3D=np.vstack((x,y,z)).T
    tri_vertices=list(map(lambda index: points3D[index], simplices)) # vertices of the surface triangles 
    
    I,J,K = mesh.faces.T

    triangles = go.Mesh3d(x=x, y=y, z=z, 
                          i=I,j=J,k=K,
                          color='darkgray', opacity=1.,
                          name='')
    
    result = [triangles]
    from plotly.express.colors import sample_colorscale
    col = sample_colorscale('peach', list(np.linspace(0, 1, 256)))
    sal = saliency/saliency.max()
    uniq_saliency = np.unique(sal)
    for level in uniq_saliency:
        # Vertices of the edges whose normalized saliency equals this level
        edge_vertices = list(map(lambda index: points3D[index], edges[sal == level]))
        # Lists Xe, Ye, Ze of x, y, z coordinates of the edge end points;
        # None separates the data of two consecutive edges
        lists_coord = [[[e[k % 2][c] for k in range(3)] + [None] for e in edge_vertices] for c in range(3)]
        Xe, Ye, Ze = [reduce(lambda x, y: x + y, lists_coord[k]) for k in range(3)]

        # One trace per level, so that each level can be toggled from the legend
        lines = go.Scatter3d(x=Xe,
                             y=Ye,
                             z=Ze,
                             mode='lines',
                             line=dict(color=col[int(level * 255)], width=3.),
                             name=f'{(level * saliency.max()):.2f}')
        result.append(lines)
    return result
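Both helpers draw many disconnected segments in a single `Scatter3d` trace by inserting `None` between them, which plotly treats as a break in the line. Concatenating the per-segment lists with `reduce(lambda x, y: x + y, ...)` copies the growing list at every step, which is quadratic in the number of segments and becomes slow on large meshes. Below is a vectorized sketch of the same idea; `segments_to_plotly_lines` is a hypothetical name, and unlike the helpers it emits each segment once (`p0, p1, None`) instead of closing it back to its first point.

```python
import numpy as np


def segments_to_plotly_lines(points, segments):
    """Return the x, y, z lists of one plotly 'lines' trace drawing many
    disconnected segments: each segment contributes p0, p1, then None."""
    points = np.asarray(points, dtype=float)
    segments = np.asarray(segments)
    # One block of 3 rows per segment: start point, end point, gap (NaN for now)
    coords = np.full((len(segments), 3, 3), np.nan)
    coords[:, 0] = points[segments[:, 0]]
    coords[:, 1] = points[segments[:, 1]]
    flat = coords.reshape(-1, 3)
    # Turn the NaN gaps into None, which plotly interprets as a line break
    return [[None if np.isnan(v) else float(v) for v in flat[:, c]] for c in range(3)]
```

To draw triangle outlines, one could pass the three edges of every face, for instance `np.concatenate([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [2, 0]]])`.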


In [6]:
data1=plotly_trisurf(mesh, colormap=cm.RdBu, plot_edges=True)

In [7]:
# Set up the scene; all plots share these settings
axis = dict(
    showbackground=True,
    backgroundcolor="rgb(230, 230, 230)",
    gridcolor="rgb(255, 255, 255)",
    zerolinecolor="rgb(255, 255, 255)",
)

layout = go.Layout(
    title='Mesh triangulation',
    width=800,
    height=800,
    scene=dict(
        xaxis=dict(axis),
        yaxis=dict(axis),
        zaxis=dict(axis),
        aspectratio=dict(x=1, y=1, z=1),
    ),
)



In [8]:
fig1 = go.Figure(data=data1, layout=layout)
fig1.show()

In [9]:
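# Choose a curvature measure and compute it at the vertices of the mesh, with a ball of radius .1
# Dividing by sphere_ball_intersection(1, .1) is an optional normalization, left commented out here
curvature = discrete_mean_curvature_measure(mesh, mesh.vertices, .1)  # / sphere_ball_intersection(1, .1)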
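The commented-out normalization divides by `sphere_ball_intersection(1, .1)`, the area of a unit sphere that falls inside a ball of radius .1 centred on the sphere. That area is a spherical cap, and it has a closed form. The sketch below is our own derivation of it (we assume, but have not verified here, that trimesh uses the same formula); `sphere_ball_intersection_area` is a hypothetical name.

```python
import math


def sphere_ball_intersection_area(R, r):
    """Area of the part of a sphere of radius R that lies inside a ball of
    radius r centred at a point of that sphere (a spherical cap)."""
    if r >= 2 * R:
        # The ball contains the whole sphere
        return 4 * math.pi * R ** 2
    # A sphere point at polar angle theta from the ball centre lies at squared
    # distance 2 R^2 (1 - cos theta); setting it to r^2 gives the cap height
    # h = R (1 - cos theta) = r^2 / (2 R)
    h = r ** 2 / (2 * R)
    # Archimedes: the area of a spherical cap is 2 pi R h (here equal to pi r^2)
    return 2 * math.pi * R * h
```

For `R = 1` and `r = .1` the cap area is `pi * 0.01`, about 0.0314: as long as `r < 2R`, it equals `pi * r**2` whatever the sphere radius.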

In [10]:
# Color each face by the mean curvature of its three vertices
data2=plotly_trisurf(mesh, colors=curvature, colormap=cm.RdBu, plot_edges=True)
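With `colors` given, `plotly_trisurf` averages the three vertex values of each face and rescales the averages to [0, 1] (through `map_z2color`) before looking up the colormap. A vectorized sketch of that step, which avoids the per-face Python loop on large meshes, could look like the following; `face_values_normalized` is a hypothetical helper.

```python
import numpy as np


def face_values_normalized(vertex_values, faces):
    """Average a per-vertex scalar over each triangle, then rescale the result
    to [0, 1] so it can be fed to a matplotlib-style colormap."""
    # Fancy indexing gives shape (n_faces, 3); average over the 3 vertices
    face_vals = np.asarray(vertex_values, dtype=float)[np.asarray(faces)].mean(axis=1)
    vmin, vmax = face_vals.min(), face_vals.max()
    if vmax == vmin:
        # Constant field: avoid dividing by zero, use the middle of the colormap
        return np.full_like(face_vals, 0.5)
    return (face_vals - vmin) / (vmax - vmin)
```

A matplotlib colormap accepts the whole array at once and returns an `(n_faces, 4)` RGBA array, which can then be converted to plotly color strings.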

In [11]:
fig2 = go.Figure(data=data2, layout=layout)
fig2.show()

In [12]:
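# Compute the dual graph: one vertex per face, one edge per pair of faces sharing a mesh edge
# edges[i] holds the two mesh vertices of the edge shared by the faces in adjacency[i]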
adjacency, edges = trimesh.graph.face_adjacency(faces=None, mesh=mesh, return_edges=True)

In [13]:
# Weight each dual edge by the mean curvature of the two end vertices of the shared mesh edge
edge_weights = (curvature[edges[:, 0]] + curvature[edges[:, 1]]) / 2.
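The curvature lives on mesh vertices, whereas the watershed needs a weight on every edge of the dual graph. Each dual edge crosses exactly one mesh edge, so its weight is read from the two end vertices of that shared edge. The sketch below reproduces the mean used in this notebook and, as a purely hypothetical alternative, a `max` reduction that keeps the stronger of the two values; `dual_edge_weights` is an illustrative name.

```python
import numpy as np


def dual_edge_weights(vertex_values, shared_edges, how="mean"):
    """Weight each dual edge (pair of adjacent faces) from a per-vertex scalar,
    using the two end vertices of the mesh edge shared by the two faces."""
    v = np.asarray(vertex_values, dtype=float)[np.asarray(shared_edges)]  # shape (n, 2)
    if how == "mean":
        return v.mean(axis=1)   # the choice made in this notebook
    if how == "max":
        return v.max(axis=1)    # hypothetical alternative: keep the stronger endpoint
    raise ValueError(f"unknown reduction: {how}")
```

Called as `dual_edge_weights(curvature, edges)`, the `mean` variant gives the same values as `edge_weights` above.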

In [14]:
import higra as hg

In [15]:
# The vertices of the dual graph are the faces of the mesh
g = hg.UndirectedGraph(mesh.faces.shape[0])
g.add_edges(adjacency[:,0].tolist(), adjacency[:,1].tolist())

In [16]:
# Compute the hierarchical watershed, with regions ranked by area
tree, altitudes = hg.watershed_hierarchy_by_area(g, edge_weights)
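`hg.watershed_hierarchy_by_area` returns the hierarchy as a tree (leaves first, then merge nodes) together with one altitude per node. As we understand it, hierarchical watersheds are built on a minimum spanning tree of the edge-weighted graph, with the area criterion deciding when regions merge. The simplest member of this family is the single-linkage merge tree, obtained with Kruskal's algorithm. The sketch below builds only that tree and is not the area-based watershed computed above; `single_linkage_tree` is a hypothetical name, and its `(parent, altitude)` output only roughly mirrors higra's representation.

```python
def single_linkage_tree(num_vertices, edges, weights):
    """Kruskal-style binary merge tree (single linkage) of an edge-weighted graph.

    Leaves are the vertices 0..num_vertices-1. Each minimum spanning tree edge,
    taken by increasing weight, merges two regions into a new node whose
    altitude is that edge's weight. Returns (parent, altitude) as lists indexed
    by node; the root is its own parent.
    """
    max_nodes = 2 * num_vertices - 1
    parent = list(range(max_nodes))
    altitude = [0.0] * max_nodes
    uf = list(range(max_nodes))             # union-find: node -> current region root

    def find(x):
        while uf[x] != x:
            uf[x] = uf[uf[x]]               # path halving
            x = uf[x]
        return x

    next_node = num_vertices
    for i in sorted(range(len(edges)), key=lambda k: weights[k]):
        u, v = edges[i]
        ru, rv = find(u), find(v)
        if ru == rv:
            continue                        # u and v already merged: not an MST edge
        parent[ru] = parent[rv] = next_node
        uf[ru] = uf[rv] = next_node
        altitude[next_node] = float(weights[i])
        next_node += 1
    return parent[:next_node], altitude[:next_node]
```

On the path 0-1-2-3 with weights 1, 3, 2, the sketch first merges {0, 1} at altitude 1, then {2, 3} at altitude 2, and finally both regions at altitude 3.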

In [17]:
# Compute the saliency map: one value per edge of g (edges with saliency 0 are not drawn below)
sal = hg.saliency(tree, altitudes)
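As we read the higra documentation, `hg.saliency` gives each edge of `g` the altitude of the lowest common ancestor of its two endpoints in `tree`, that is, the level at which the two faces first fall into the same region. For the single-linkage hierarchy, that altitude equals the minimax path weight between the endpoints. The toy sketch below computes this single-linkage saliency directly with a (min, max) Floyd-Warshall; it does not reproduce the saliency of the area-based watershed computed above, and `single_linkage_saliency` is a hypothetical name.

```python
def single_linkage_saliency(num_vertices, edges, weights):
    """Saliency of every edge for the single-linkage hierarchy of a small graph.

    For that hierarchy, the altitude at which u and v first fall in the same
    region equals the minimax path weight between u and v: the smallest, over
    all u-v paths, of the largest edge weight on the path.
    """
    inf = float("inf")
    d = [[inf] * num_vertices for _ in range(num_vertices)]
    for i in range(num_vertices):
        d[i][i] = -inf                      # a vertex is in its own region at any altitude
    for (u, v), w in zip(edges, weights):
        d[u][v] = d[v][u] = min(d[u][v], float(w))
    # Floyd-Warshall in the (min, max) semiring: O(n^3), for toy graphs only
    for k in range(num_vertices):
        for i in range(num_vertices):
            for j in range(num_vertices):
                d[i][j] = min(d[i][j], max(d[i][k], d[k][j]))
    return [d[u][v] for u, v in edges]
```

Adding an edge (0, 3) of weight 5 to the path 0-1-2-3 with weights 1, 3, 2 gives that edge a saliency of 3 rather than 5: the path 0-1-2-3 already connects its endpoints at altitude 3, the altitude of the root of the single-linkage tree.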


In [18]:
# Plot it
# You can click on the legend to hide/show a level of the saliency map
data3=plotly_trisurf_saliency(mesh, edges=edges, saliency=sal)
fig3 = go.Figure(data=data3, layout=layout)
fig3.show()
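Each legend entry corresponds to one distinct saliency value: `plotly_trisurf_saliency` drops the edges of saliency 0, normalizes the rest by the maximum, and emits one `Scatter3d` trace per distinct normalized level. The grouping step alone can be sketched as follows; `group_edges_by_saliency` is a hypothetical helper, not part of higra or plotly.

```python
import numpy as np


def group_edges_by_saliency(edges, saliency):
    """Drop zero-saliency edges, normalize saliency by its maximum, and return
    a dict {normalized level: edges at that level}."""
    edges = np.asarray(edges)
    saliency = np.asarray(saliency, dtype=float)
    keep = saliency > 0
    edges, saliency = edges[keep], saliency[keep]
    if saliency.size == 0:
        return {}
    levels = saliency / saliency.max()
    return {float(lv): edges[levels == lv] for lv in np.unique(levels)}
```

On a large mesh the hierarchy can have many distinct levels, hence many traces; quantizing the normalized levels (for instance into 256 bins, matching the 256 sampled colors) would keep the legend short, at the cost of merging nearby levels.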