local graph = require "./graph"
type Node = graph.Node
type SourceNode = graph.SourceNode

local create_node = graph.create_node
local create_source_node = graph.create_source_node
local assert_stable_scope = graph.assert_stable_scope
local evaluate_node = graph.evaluate_node
local update_descendants = graph.update_descendants
local push_scope_as_child_of = graph.push_scope_as_child_of

-- fixed simulation steps per unit of `dt` (see the stepper returned at the bottom)
local UPDATE_RATE = 120
-- a spring settles once it is within (|initial - goal| / TOLERANCE_FACTOR) of its goal
local TOLERANCE_FACTOR = 10_000
-- absolute lower bound on the settle tolerance. Without it, any component whose
-- initial position equals its goal (e.g. a component springing to 0 while x0 is
-- still the default zero) gets a tolerance of exactly 0, so the spring could only
-- settle on exact float equality, which the integrator may take a very long time
-- to reach or never reach at all.
local MIN_TOLERANCE = 1e-5

type Animatable = number | CFrame | Color3 | UDim | UDim2 | Vector2 | Vector3 | Rect | { number }
--[[
	Unsupported datatypes:
	- bool
	- Vector2int16
	- Vector3int16
	- EnumItem
]]

-- Every animated value is packed into two 3-component vectors: components 1-3
-- live in the `_123` vector and components 4-6 in the `_456` vector.
type SpringState<T> = {
	k: number, -- spring constant
	c: number, -- damping coeff
	x0_123: vector, x0_456: vector, -- initial position (only used to scale the settle tolerance)
	x_123: vector, x_456: vector, -- current position
	x1_123: vector, x1_456: vector, -- target position
	v_123: vector, v_456: vector, -- current velocity
	source_value: T -- current value of spring input source
}

type SpringSettings<T> = ({ position: T?, velocity: T?, impulse: T? }) -> ()

type TypeToVec6<T> = (T) -> (vector, vector)
type Vec6ToType<T> = (vector, vector) -> T

-- converts a value of each supported datatype (keyed by `typeof`) into its
-- two-vector representation
local type_to_vec6 = {
	number = function(v)
		return vector.create(v, 0, 0), vector.zero
	end :: TypeToVec6<number>,

	CFrame = function(v)
		-- position + XYZ euler angles
		return v.Position, vector.create(v:ToEulerAnglesXYZ())
	end :: TypeToVec6<CFrame>,

	Color3 = function(v) -- todo: hsv, oklab?
		return vector.create(v.R, v.G, v.B), vector.zero
	end :: TypeToVec6<Color3>,

	UDim = function(v)
		return vector.create(v.Scale, v.Offset, 0), vector.zero
	end :: TypeToVec6<UDim>,

	UDim2 = function(v)
		return vector.create(v.X.Scale, v.X.Offset, v.Y.Scale), vector.create(v.Y.Offset, 0, 0)
	end :: TypeToVec6<UDim2>,

	Vector2 = function(v)
		return vector.create(v.X, v.Y, 0), vector.zero
	end :: TypeToVec6<Vector2>,

	Vector3 = function(v)
		return v, vector.zero
	end :: TypeToVec6<Vector3>,

	Rect = function(v)
		return vector.create(v.Min.X, v.Min.Y, v.Max.X), vector.create(v.Max.Y, 0, 0)
	end :: TypeToVec6<Rect>,

	-- only the first four entries are animated; missing entries read as 0
	table = function(v)
		return vector.create(v[1] or 0, v[2] or 0, v[3] or 0), vector.create(v[4] or 0, 0, 0)
	end :: TypeToVec6<{ number }>,
}

-- inverse of type_to_vec6: rebuilds a value of each datatype from its two vectors
local vec6_to_type = {
	number = function(a, b)
		return a.X
	end :: Vec6ToType<number>,

	CFrame = function(a, b)
		return CFrame.new(a) * CFrame.fromEulerAnglesXYZ(b.X, b.Y, b.Z)
	end :: Vec6ToType<CFrame>,

	Color3 = function(v)
		-- overshoot is clamped back into the 0-1 channel range
		return Color3.new(math.clamp(v.X, 0, 1), math.clamp(v.Y, 0, 1), math.clamp(v.Z, 0, 1))
	end :: Vec6ToType<Color3>,

	UDim = function(v)
		-- offsets are rounded to whole numbers
		return UDim.new(v.X, math.round(v.Y))
	end :: Vec6ToType<UDim>,

	UDim2 = function(a, b)
		return UDim2.new(a.X, math.round(a.Y), a.Z, math.round(b.X))
	end :: Vec6ToType<UDim2>,

	Vector2 = function(v)
		return Vector2.new(v.X, v.Y)
	end :: Vec6ToType<Vector2>,

	Vector3 = function(v)
		return v
	end :: Vec6ToType<Vector3>,

	Rect = function(a, b)
		return Rect.new(a.X, a.Y, a.Z, b.X)
	end :: Vec6ToType<Rect>,

	table = function(a, b)
		return { a.X, a.Y, a.Z, b.X }
	end :: Vec6ToType<{ number }>,
}

-- looking up a datatype with no converter raises "cannot spring type <name>"
local invalid_type = {
	__index = function(_, t: string)
		error(`cannot spring type {t}`, 0)
	end,
}

setmetatable(type_to_vec6, invalid_type)
setmetatable(vec6_to_type, invalid_type)

-- maps spring data to its corresponding output node; presence in this table
-- means the spring is scheduled for simulation.
-- lifetime of spring data is tied to output node (values are weak)
local springs: { [SpringState<any>]: SourceNode } = {}
setmetatable(springs :: any, { __mode = "v" })

--[[
	Creates a spring that follows the value returned by `source`.
	Must be called inside a stable graph scope (checked by assert_stable_scope).

	period        natural (undamped) oscillation period, in the same time unit
	              as the stepper's `dt`; defaults to 1
	damping_ratio multiplier on the critical damping coefficient; defaults to 1

	Errors if the resulting damping coefficient is too large for the fixed-step
	solver.

	Returns:
	- a getter/setter for the animated value (see below)
	- a config function for overriding position / velocity or adding an impulse
]]
local function spring<T>(source: () -> T, period: number?, damping_ratio: number?): (() -> T, SpringSettings<T>)
	local owner = assert_stable_scope()

	-- https://en.wikipedia.org/wiki/Damping
	local w_n = 2*math.pi / (period or 1) -- natural angular frequency
	local z = damping_ratio or 1
	local k = w_n^2
	local c_c = 2*w_n -- critical damping coefficient (unit mass)
	local c = z * c_c

	-- todo: is there a solution other than reducing step size?
	-- todo: this does not catch all solver exploding cases
	if c > UPDATE_RATE*2 then -- solver will explode if this is true
		error("spring damping too high, consider reducing damping or increasing period", 0)
	end

	local data: SpringState<T> = {
		k = k,
		c = c,
		x0_123 = vector.zero, x_123 = vector.zero, x1_123 = vector.zero, v_123 = vector.zero,
		x0_456 = vector.zero, x_456 = vector.zero, x1_456 = vector.zero, v_456 = vector.zero,
		source_value = false :: any,
	}

	local output = create_source_node(false :: any)

	-- node effect: reads the source, retargets the spring at its value and
	-- schedules the spring for simulation
	local function updater_effect()
		local value = source()
		data.x1_123, data.x1_456 = type_to_vec6[typeof(value)](value)
		data.source_value = value
		springs[data] = output
		return value
	end

	local updater = create_node(owner, updater_effect, false :: any)
	evaluate_node(updater)

	-- set initial position to goal
	data.x_123, data.x_456 = data.x1_123, data.x1_456

	-- set output to goal
	output.cache = data.source_value

	-- overrides the simulated state; any combination of fields may be given
	local config = function(p)
		local x = p.position
		local v = p.velocity
		local dv = p.impulse

		if x then
			local x_123, x_456 = type_to_vec6[typeof(x)](x)
			data.x_123, data.x_456 = x_123, x_456
			-- also recorded as the initial position, which scales the settle tolerance
			data.x0_123, data.x0_456 = x_123, x_456
		end

		if v then
			data.v_123, data.v_456 = type_to_vec6[typeof(v)](v)
		end

		if dv then
			local dv_123, dv_456 = type_to_vec6[typeof(dv)](dv)
			data.v_123 += dv_123
			data.v_456 += dv_456
		end

		-- schedule spring
		springs[data] = output
	end :: SpringSettings<T>

	-- With no arguments: returns the current animated value, after calling
	-- push_scope_as_child_of(output) (presumably registers the calling scope as
	-- a dependent of this spring).
	-- With a value: snaps the position to it, zeroes velocity and returns it.
	-- The goal is unchanged, so the spring animates back toward the source's
	-- value; dependents are only notified on the next update pass.
	return function(...)
		if select("#", ...) == 0 then -- no args were given
			push_scope_as_child_of(output)
			return output.cache
		end

		-- set current position to value
		local v = ... :: T
		data.x_123, data.x_456 = type_to_vec6[typeof(v)](v)

		-- reset velocity
		data.v_123 = vector.zero
		data.v_456 = vector.zero

		-- schedule spring
		springs[data] = output

		-- set output to value
		output.cache = v
		return v
	end, config
end

-- calculates a float tolerance, based on the magnitude of the float
local function get_min_step(x: number): number
	return x/TOLERANCE_FACTOR
end

-- component-wise get_min_step
local function get_min_vector_step(direction: vector): vector
	return vector.create(
		get_min_step(direction.x),
		get_min_step(direction.y),
		get_min_step(direction.z)
	)
end

local MIN_TOLERANCE_VECTOR = vector.create(MIN_TOLERANCE, MIN_TOLERANCE, MIN_TOLERANCE)

-- per-component settle tolerance for a transition from `from` to `to`,
-- never smaller than MIN_TOLERANCE
local function get_tolerance(from: vector, to: vector): vector
	return vector.max(vector.abs(get_min_vector_step(from - to)), MIN_TOLERANCE_VECTOR)
end

-- true when every component of |delta| is <= the matching component of `tol`
local function within(delta: vector, tol: vector): boolean
	return vector.max(vector.abs(delta), tol) == tol
end

-- Advances every scheduled spring by one step of `dt` using semi-implicit
-- (symplectic) Euler: velocity is updated from the current acceleration first,
-- then position is updated from the new velocity. Mass is taken as 1.
local function step_springs(dt: number)
	for s in springs do
		local k = s.k
		local c = s.c
		local x_123, x_456 = s.x_123, s.x_456
		local x1_123, x1_456 = s.x1_123, s.x1_456
		local u_123, u_456 = s.v_123, s.v_456

		-- displacement from target
		local dx_123 = x_123 - x1_123
		local dx_456 = x_456 - x1_456

		-- acceleration = spring force + friction force
		local a_123 = dx_123*-k + u_123*-c
		local a_456 = dx_456*-k + u_456*-c

		-- integrate acceleration into velocity
		local v_123 = u_123 + a_123*dt
		local v_456 = u_456 + a_456*dt

		-- integrate the new velocity into position
		s.x_123, s.x_456 = x_123 + v_123*dt, x_456 + v_456*dt
		s.v_123, s.v_456 = v_123, v_456
	end
end

-- Publishes the current value of every scheduled spring to its output node and
-- notifies dependents. Springs whose position and velocity are both within
-- tolerance of the goal are snapped to the exact source value and unscheduled.
local function update_spring_sources()
	-- Snapshot the schedule before publishing anything: update_descendants()
	-- may re-run updater effects of downstream springs (assumed graph
	-- behaviour), which assign new keys into `springs`, and adding keys to a
	-- table while traversing it is undefined. The arrays also keep each output
	-- strongly referenced for the duration of this pass.
	local scheduled: { SpringState<any> } = {}
	local outputs: { SourceNode } = {}
	local count = 0
	for data, output in springs do
		count += 1
		scheduled[count] = data
		outputs[count] = output
	end

	for i = 1, count do
		local data = scheduled[i]
		local output = outputs[i]

		local x_123, x_456 = data.x_123, data.x_456
		local x1_123, x1_456 = data.x1_123, data.x1_456
		local v_123, v_456 = data.v_123, data.v_456

		local tol_123 = get_tolerance(data.x0_123, x1_123)
		local tol_456 = get_tolerance(data.x0_456, x1_456)

		local settled =
			-- position is at goal (within tolerance)
			within(x_123 - x1_123, tol_123)
			and within(x_456 - x1_456, tol_456)
			-- velocity is at 0 (within tolerance)
			and within(v_123/10, tol_123)
			and within(v_456/10, tol_456)

		if settled then
			springs[data] = nil
			output.cache = data.source_value
		else
			output.cache = vec6_to_type[typeof(data.source_value)](x_123, x_456)
		end

		update_descendants(output)
	end
end

-- Module factory. Returns the `spring` constructor and a stepper that takes the
-- elapsed time `dt`, runs as many fixed 1/UPDATE_RATE simulation steps as have
-- accumulated, then publishes spring values. Each call gets its own time
-- accumulator, but every stepper steps the same module-level `springs` set.
return function()
	local time_elapsed = 0

	return spring, function(dt: number)
		time_elapsed += dt
		while time_elapsed > 1 / UPDATE_RATE do
			time_elapsed -= 1 / UPDATE_RATE
			step_springs(1 / UPDATE_RATE)
		end
		update_spring_sources()
	end
end