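<!-- Unit-cell renderer: draws the cell spanned by a 3x3 lattice matrix as cylinder edges and translucent faces, plus optional lattice-vector arrows with a hover tooltip -->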
<script lang="ts">
  import { format_num } from '../labels'
  import type { Vec3 } from '../math'
  import * as math from '../math'
  import { DEFAULTS } from '../settings'
  import { CanvasTooltip } from './'
  import { T } from '@threlte/core'
  import {
    BoxGeometry,
    EdgesGeometry,
    Euler,
    Matrix4,
    Quaternion,
    Vector3,
  } from 'three'

  interface Props {
    matrix?: math.Matrix3x3 // lattice vectors as rows; nothing is rendered while undefined
    cell_edge_color?: string // color of the cell edge cylinders
    cell_surface_color?: string // color of the cell faces
    cell_edge_width?: number // thickness of the cell edges (cylinder radius = width * 0.01)
    cell_edge_opacity?: number // opacity of the cell edges; edges are skipped when <= 0
    cell_surface_opacity?: number // opacity of the cell surfaces; faces are skipped when <= 0
    show_cell_vectors?: boolean // whether to show the lattice vector arrows
    vector_colors?: readonly [string, string, string] // arrow colors for the A, B, C vectors
    vector_origin?: Vec3 // lattice vector origin (all arrows start from this point)
    float_fmt?: string // number format passed to format_num for tooltip coordinates
  }

  let {
    matrix = undefined,
    cell_edge_color = DEFAULTS.structure.cell_edge_color,
    cell_surface_color = DEFAULTS.structure.cell_surface_color,
    cell_edge_width = DEFAULTS.structure.cell_edge_width,
    cell_edge_opacity = DEFAULTS.structure.cell_edge_opacity,
    cell_surface_opacity = DEFAULTS.structure.cell_surface_opacity,
    show_cell_vectors = true,
    vector_colors = [`red`, `green`, `blue`],
    vector_origin = [-1, -1, -1] satisfies Vec3,
    float_fmt = `.2f`,
  }: Props = $props()

  // Index (0-2) of the lattice vector under the pointer; null when none is hovered
  let hovered_idx = $state<number | null>(null)

  // Cell center = half the sum of the three lattice vectors; the origin when there is no matrix
  let lattice_center = $derived.by(() => {
    if (!matrix) return [0, 0, 0] satisfies Vec3
    return math.scale(math.add(...matrix), 0.5) satisfies Vec3
  })

  // Extract line segments from EdgesGeometry for cylinder-based thick lines
  /**
   * Split an EdgesGeometry into its individual line segments so each edge can be drawn
   * as a cylinder (thick lines).
   *
   * The `position` buffer is read as consecutive groups of 6 floats:
   * (start.x, start.y, start.z, end.x, end.y, end.z).
   * @returns one [start, end] pair of fresh Vector3 instances per group
   */
  function get_edge_segments(edges_geometry: EdgesGeometry): [Vector3, Vector3][] {
    const coords = edges_geometry.getAttribute(`position`).array as Float32Array
    const n_segments = Math.ceil(coords.length / 6)

    return Array.from({ length: n_segments }, (_, seg_idx): [Vector3, Vector3] => {
      const base = seg_idx * 6
      const seg_start = new Vector3(coords[base], coords[base + 1], coords[base + 2])
      const seg_end = new Vector3(coords[base + 3], coords[base + 4], coords[base + 5])
      return [seg_start, seg_end]
    })
  }

  // Calculate cylinder transform for a line segment
  /**
   * Compute the transform that lays a cylinder mesh along the segment start -> end.
   *
   * The cylinder is assumed to be built along +Y and centered on its own origin, so the
   * mesh goes at the segment midpoint and is rotated so that +Y points from start to end.
   * @returns position (segment midpoint), rotation (Euler angles from the alignment
   *   quaternion) and the segment length. A zero-length segment gets rotation [0, 0, 0].
   */
  function get_cylinder_transform(
    start: Vector3,
    end: Vector3,
  ): { position: Vec3; rotation: Vec3; length: number } {
    const midpoint = new Vector3().addVectors(start, end).multiplyScalar(0.5)
    const position: Vec3 = [midpoint.x, midpoint.y, midpoint.z]

    const axis = new Vector3().subVectors(end, start)
    const length = axis.length()

    // Degenerate segment: no direction to align with, so leave the cylinder unrotated
    if (length === 0) return { position, rotation: [0, 0, 0], length }

    const y_up = new Vector3(0, 1, 0)
    const alignment = new Quaternion().setFromUnitVectors(y_up, axis.normalize())
    const { x, y, z } = new Euler().setFromQuaternion(alignment)

    return { position, rotation: [x, y, z], length }
  }
</script>

{#if matrix}
  {#key matrix}
    <!-- Shear a unit cube into the cell: makeBasis takes the lattice vectors (rows of `matrix`)
    as basis columns, so the cube maps onto the parallelepiped they span, centered on the origin -->
    {@const shear_matrix = new Matrix4().makeBasis(
      new Vector3(...matrix[0]),
      new Vector3(...matrix[1]),
      new Vector3(...matrix[2]),
    )}
    {@const box_geometry = new BoxGeometry(1, 1, 1).applyMatrix4(shear_matrix)}

    <!-- Render wireframe edges if edge opacity > 0 -->
    {#if cell_edge_opacity > 0}
      {@const edges_geometry = new EdgesGeometry(box_geometry)}
      {@const edge_segments = get_edge_segments(edges_geometry)}

      <!-- Use cylinders for thick wireframe lines; the group shifts the origin-centered cell
      by lattice_center so it spans from the origin to the sum of the lattice vectors -->
      <T.Group position={lattice_center}>
        {#each edge_segments as [start, end], idx (idx)}
          {@const { position, rotation, length } = get_cylinder_transform(start, end)}
          <T.Mesh {position} {rotation}>
            <T.CylinderGeometry
              args={[cell_edge_width * 0.01, cell_edge_width * 0.01, length, 8]}
            />
            <T.MeshStandardMaterial
              color={cell_edge_color}
              opacity={cell_edge_opacity}
              transparent
              depthWrite={false}
            />
          </T.Mesh>
        {/each}
      </T.Group>
    {/if}

    <!-- Render transparent surfaces if surface opacity > 0 -->
    {#if cell_surface_opacity > 0}
      <T.Mesh geometry={box_geometry} position={lattice_center}>
        <T.MeshStandardMaterial
          color={cell_surface_color}
          opacity={cell_surface_opacity}
          transparent
          depthWrite={false}
        />
      </T.Mesh>
    {/if}

    <!-- NOTE(review): the original author flagged this arrow rendering as an untested fix for
    lattice vectors looking much too small in deployed builds while correct in local dev —
    verify against a production build -->
    {#if show_cell_vectors}
      <T.Group position={vector_origin}>
        <!-- Keyed by index, like the edge loop: keying by the row arrays themselves would
        produce duplicate keys if two rows shared one array reference -->
        {#each matrix as vec, idx (idx)}
          <!-- Shaft runs from the group origin to 85% of the vector; the shared helper gives its
          midpoint, the rotation aligning +Y with the vector, and its length -->
          {@const { position: shaft_center, rotation, length: shaft_length } =
            get_cylinder_transform(new Vector3(0, 0, 0), new Vector3(...vec).multiplyScalar(0.85))}
          <!-- The cone is centered where the shaft ends and shares the shaft's rotation -->
          {@const tip_position = math.scale(vec, 0.85)}
          <!-- Hovering either the shaft or the tip selects this vector for the tooltip -->
          <T.Mesh
            position={shaft_center}
            {rotation}
            onpointerenter={() => hovered_idx = idx}
            onpointerleave={() => hovered_idx = null}
          >
            <T.CylinderGeometry args={[0.05, 0.05, shaft_length, 16]} />
            <T.MeshStandardMaterial color={vector_colors[idx]} />
          </T.Mesh>

          <!-- Arrow tip -->
          <T.Mesh
            position={tip_position}
            {rotation}
            onpointerenter={() => hovered_idx = idx}
            onpointerleave={() => hovered_idx = null}
          >
            <T.ConeGeometry args={[0.175, 0.5, 16]} />
            <T.MeshStandardMaterial color={vector_colors[idx]} />
          </T.Mesh>
        {/each}
      </T.Group>

      <!-- Tooltip for the hovered vector, anchored at its end point (origin + vector) -->
      {#if hovered_idx !== null && matrix}
        {@const hovered_vec = matrix[hovered_idx]}
        {@const tooltip_position = math.add(vector_origin, hovered_vec)}
        <CanvasTooltip position={tooltip_position}>
          <strong>{[`A`, `B`, `C`][hovered_idx]}</strong>
          ({hovered_vec.map((coord) => format_num(coord, float_fmt)).join(`, `)}) Å
        </CanvasTooltip>
      {/if}
    {/if}
  {/key}
{/if}
