#
# Based on:
#http://viola.informatik.uni-bremen.de/typo/html/qingteng/index.html
#
# Documentation :
# http://www.scratchapixel.com
#
# Thanks to Alexander Tsepkov the creator of RapydScript, and
# Charles Law for allowing testing RapydScript online
# Post comments and suggestions on RapydScript Google Group
#
# You can modify the script and see the changes
# almost live.
# While modifying, you can reduce WIDTH and HEIGHT to avoid lag.
#
# Authored by Salvatore DI DIO (email : salvatore dot didio at gmail)
#
# Best performance with Firefox
# On my MacBook Pro: 1280 x 800 rendered in less than a minute

# Canvas size in pixels; main() assigns these to canvas.width and canvas.height.
WIDTH = 440
HEIGHT = 240

# 3D vector with x, y, z components.
# Every method returns a new Vector (or a number); none of them modifies
# self or its argument.
class Vector:
    def __init__(self, x=0, y=0, z=0):
        self.x = x
        self.y = y
        self.z = z

    # Return a new Vector with the same components.
    def copy(self):
        return Vector(self.x, self.y, self.z)

    # Squared Euclidean length (no square root).
    def sqrLength(self):
        return (self.x*self.x + self.y*self.y + self.z*self.z)

    # Euclidean length.
    def length(self):
        return Math.sqrt(self.x*self.x + self.y*self.y + self.z*self.z)

    # Return this vector scaled by 1/length. No guard for zero length.
    def normalize(self):
        scale = 1.0/self.length()
        return Vector(self.x*scale, self.y*scale, self.z*scale)

    # Return the vector with every component sign-flipped.
    def negate(self):
        return Vector(-self.x, -self.y, -self.z)

    # Component-wise sum: self + v.
    def add(self, v):
        sx = self.x + v.x
        sy = self.y + v.y
        sz = self.z + v.z
        return Vector(sx, sy, sz)

    # Component-wise difference: self - v.
    def sub(self, v):
        dx = self.x - v.x
        dy = self.y - v.y
        dz = self.z - v.z
        return Vector(dx, dy, dz)

    # Scalar multiple: k * self.
    def times(self, k):
        return Vector(k*self.x, k*self.y, k*self.z)

    # Dot product of self and t.
    def dot(self, t):
        return self.x*t.x + self.y*t.y+self.z*t.z

    # Cross product: self x w.
    def cross(self, w):
        cx = self.y * w.z - self.z * w.y
        cy = -self.x * w.z + self.z * w.x
        cz = self.x * w.y - self.y * w.x
        return Vector(cx, cy, cz)

    # Return a new (0, 0, 0) vector.
    @staticmethod
    def zero():
        return Vector()

# A ray described by an origin and a direction.
class Ray:
    def __init__(self, vectOrigin, vectDest):
        self.direction = vectDest
        self.origin = vectOrigin

    # Return the point at parameter t along the ray: origin + t * direction.
    def getPoint(self, t):
        offset = self.direction.times(t)
        return self.origin.add(offset)

# Record describing a ray/object intersection.
# A freshly constructed instance has sceneobject == null, which the scene
# code in this file checks to detect "no hit".
class IntersectionResult:
    def __init__(self):
        self.normal = Vector.zero()
        self.position = Vector.zero()
        self.distance = 0
        self.sceneobject = null

    # Return a result that represents a miss (sceneobject is null).
    @staticmethod
    def nohit():
        missed = IntersectionResult()
        return missed
    
# Sphere primitive described by a center point and a radius.
class Sphere:
    def __init__(self, center = Vector(0,2,-1), radius = 1):
        self.center = center
        self.radius = radius
        self.name = "Sphere"
        # Computed here as well so intersect() works even if initialize()
        # has not been called yet.
        self.sqrRadius = radius * radius

    # Refresh the cached squared radius from the current radius.
    def initialize(self):
        self.sqrRadius = self.radius * self.radius

    # Returns the constant 100; not referenced anywhere in the visible source.
    def double(self):
        return 100

    # Intersect the sphere with ray and return an IntersectionResult.
    # The quadratic is solved without a direction-length term, so it assumes
    # ray.direction is unit length.
    # Returns IntersectionResult.nohit() when the ray misses, points away
    # from the sphere, or the nearest root lies behind the ray origin.
    def intersect(self, ray):
        v = ray.origin.sub(self.center)
        a0 = v.sqrLength() - self.sqrRadius
        DdotV = ray.direction.dot(v)

        # DdotV > 0 means the ray is heading away from the center.
        if DdotV <= 0:
            discr = DdotV * DdotV - a0
            if discr >= 0:
                distance = -DdotV - Math.sqrt(discr)
                # A negative root is behind the origin (e.g. origin inside
                # the sphere); that is not a visible hit.
                if distance >= 0:
                    r = IntersectionResult()
                    r.sceneobject = self
                    r.distance = distance
                    r.position = ray.getPoint(distance)
                    # Normalize so shading and reflection stay correct for
                    # radii other than 1.
                    r.normal = r.position.sub(self.center).normalize()
                    return r
        return IntersectionResult.nohit()

# Perspective camera: eye position, viewing direction, reference up vector,
# field of view (converted from degrees in initialize) and aspect ratio.
class Camera:
    def __init__(self, eye, front, up, fov, aspectratio):
        self.eye = eye
        self.front = front
        self.RefUp = up
        self.fov = fov
        self.aspectratio = aspectratio

    # Derive the right/up basis vectors and the field-of-view scale factor.
    # Must run before generateRay(), which reads all three.
    def initialize(self):
        halfAngle = self.fov * 0.5 * Math.PI/180
        self.fovScale = Math.tan(halfAngle)*2
        self.right = self.front.cross(self.RefUp)
        self.up = self.right.cross(self.front)

    # Build the ray through screen coordinate (x, y); 0.5 on either axis
    # maps to the view center. The direction is normalized.
    def generateRay(self, x, y):
        horizontal = self.right.times((x - 0.5) * self.fovScale*self.aspectratio)
        vertical = self.up.times((y - 0.5) * self.fovScale)
        direction = self.front.add(horizontal).add(vertical).normalize()
        return Ray(self.eye, direction)

# Composite scene object: a list of geometries queried as one.
class Union:
    def __init__(self, geometries):
        self.geometries = geometries

    # Forward initialize() to every contained geometry.
    def initialize(self):
        for member in self.geometries:
            member.initialize()

    # Return the hit with the smallest distance among all geometries, or a
    # no-hit IntersectionResult when nothing is hit. On equal distances the
    # earlier geometry in the list wins (strict '<' comparison).
    def intersect(self, ray):
        closest = IntersectionResult()
        closestDistance = Infinity
        for member in self.geometries:
            candidate = member.intersect(ray)
            if candidate.sceneobject == null:
                continue
            if candidate.distance < closestDistance:
                closestDistance = candidate.distance
                closest = candidate
        return closest

# Infinite one-sided plane given by a normal and a distance along that
# normal (the plane passes through normal * distance).
class Plane:
    def __init__(self, normal, distance):
        self.normal = normal
        self.d = distance
        self.name = "Plane"
        # Computed here as well so intersect() works even if initialize()
        # has not been called yet.
        self.position = normal.times(distance)

    # Refresh the cached point on the plane from the current normal and d.
    def initialize(self):
        self.position = self.normal.times(self.d)

    # Intersect the plane with ray and return an IntersectionResult.
    # Returns IntersectionResult.nohit() when the ray is parallel to the
    # plane, approaches it from the back side, or would only meet it behind
    # the ray origin.
    def intersect(self, ray):
        a = ray.direction.dot(self.normal)
        if a >= 0:
            return IntersectionResult.nohit()

        b = self.normal.dot(ray.origin.sub(self.position))
        distance = -b/a
        # A negative distance means the origin is already below the plane
        # and the crossing point lies behind the ray: not a visible hit.
        if distance < 0:
            return IntersectionResult.nohit()

        result = IntersectionResult()
        result.sceneobject = self
        result.distance = distance
        result.position = ray.getPoint(distance)
        result.normal = self.normal
        return result

# RGB color triple. Every operation returns a new Color and leaves the
# operands unchanged.
class Color:
    def __init__(self, r=0, g=0, b=0):
        self.r = r
        self.g = g
        self.b = b

    # Channel-wise sum of self and c.
    def add(self, c):
        red = self.r + c.r
        green = self.g + c.g
        blue = self.b + c.b
        return Color(red, green, blue)

    # Scale every channel by the number k.
    def times(self, k):
        red = k*self.r
        green = k*self.g
        blue = k*self.b
        return Color(red, green, blue)

    # Channel-wise product of self and c.
    def modulate(self, c):
        red = self.r * c.r
        green = self.g * c.g
        blue = self.b * c.b
        return Color(red, green, blue)

# Named colors attached to the Color class. main() multiplies each channel
# by 255 when writing pixels. These instances are shared; Color's own
# methods return new objects and never modify them.
Color.black = Color()
Color.red = Color(1,0,0)
Color.white = Color(1,1,1)
Color.green = Color(0,1,0)
Color.blue = Color(0,0,1)
Color.gold = Color(1,0.8745098039215686,0)
Color.cyan = Color(0,1,1)

# Two-tone checkerboard material evaluated in the x/z plane.
#   scale          - number of checker cells per world unit on both axes
#   reflectiveness - stored for the ray tracer; not used by sample()
class CheckerMaterial:
    def __init__(self, scale, reflectiveness):
        self.scale = scale
        self.reflectiveness = reflectiveness

    # Return Color.black or Color.white depending on which checker cell
    # position falls into. ray and normal are accepted for interface parity
    # with the other materials and are not used.
    def sample(self, ray, position, normal):
        # Both axes use self.scale so the cells stay square for any scale
        # (the x axis used to be hard-coded to 0.1).
        cell = Math.floor(position.x * self.scale) + Math.floor(position.z * self.scale)
        # '%' can yield -1 for negative cell indices, hence the abs().
        v = Math.abs(cell % 2)
        if v < 1:
            return Color.black
        return Color.white

# Single global light read by PhongMaterial.sample(): a normalized
# direction vector and the light's color.
lightDir =  Vector(6, 8.5, 15).normalize()
lightColor = Color(1,1,1)

# Material with a diffuse term and a specular highlight, lit by the global
# lightDir / lightColor.
#   diffuse, specular - Color coefficients
#   shininess         - exponent applied to the specular term
#   reflectiveness    - stored for the ray tracer; not used by sample()
class PhongMaterial:
    def __init__(self, diffuse, specular, shininess, reflectiveness):
        self.diffuse = diffuse
        self.specular = specular
        self.shininess = shininess
        self.reflectiveness = reflectiveness

    # Return the lit color at a surface point. position is accepted for
    # interface parity with the other materials and is not used.
    def sample(self, ray, position, normal):
        # Half vector between the light direction and the reversed ray.
        halfway = (lightDir.sub(ray.direction)).normalize()
        diffuseWeight = Math.max(normal.dot(lightDir), 0)
        specularWeight = Math.pow(Math.max(normal.dot(halfway), 0), self.shininess)
        diffusePart = self.diffuse.times(diffuseWeight)
        specularPart = self.specular.times(specularWeight)
        return lightColor.modulate(diffusePart.add(specularPart))

# Trace ray through scene and return the resulting Color.
#   scene      - object with intersect(ray) returning an IntersectionResult
#   ray        - ray to trace
#   maxReflect - number of reflection bounces still allowed
# Returns black when the ray hits nothing.
def rayTraceRecursive(scene, ray, maxReflect):
    result = scene.intersect(ray)
    if not result.sceneobject:
        return Color(0,0,0)

    mat = result.sceneobject.material
    reflectiveness = mat.reflectiveness
    # Weight the surface's own color by (1 - reflectiveness). The scaled
    # color must be kept: times() returns a new Color and does not modify
    # the one it is called on.
    color = mat.sample(ray, result.position, result.normal).times(1 - reflectiveness)

    if (reflectiveness > 0 and maxReflect > 0):
        # Mirror the incoming direction about the surface normal:
        # r = d - 2 (n . d) n
        r = result.normal.times(-2 * result.normal.dot(ray.direction)).add(ray.direction)
        reflectedRay = Ray(result.position, r)
        reflectedColor = rayTraceRecursive(scene, reflectedRay, maxReflect - 1)
        color = color.add(reflectedColor.times(reflectiveness))
    return color

# When True the floor plane uses CheckerMaterial, otherwise a PhongMaterial.
CHECKER_PATTERN = True

# Build the scene, ray trace it and draw the image into the element with
# id "canvas".
def main():
    canvas = document.getElementById("canvas")
    ctx = canvas.getContext("2d")
    canvas.width = WIDTH
    canvas.height = HEIGHT
    w = canvas.width
    h = canvas.height
    ctx.fillStyle = "rgb(255,255,255)"
    ctx.fillRect(0, 0, w, h)

    aspectratio = w/h
    camera = Camera(Vector(0,0.56,8),Vector(0,0,-1),Vector(0,1,0),30,aspectratio)
    camera.initialize()

    sphere = Sphere(Vector(1.5,0,-1),1)
    sphere.material = PhongMaterial(Color.red, Color.white, 500,0.25)

    sphere2 = Sphere(Vector(-1.5,0,-1),1)
    sphere2.material = PhongMaterial(Color.gold, Color.white, 500,0.25)

    sphere3 = Sphere(Vector(0,0,-2),1)
    sphere3.material = PhongMaterial(Color.blue, Color.white, 500,0.25)

    plane = Plane(Vector(0,1,0),-1)
    if CHECKER_PATTERN:
        plane.material = CheckerMaterial(0.1,0.9)
    else:
        plane.material = PhongMaterial(Color.cyan, Color.white, 500,1.0)

    scene = Union([plane,sphere,sphere2,sphere3])
    scene.initialize()

    # Maximum number of reflection bounces per primary ray.
    maxReflect = 4

    imgdata = ctx.createImageData(w,h)
    for y in range(h):
        # Flip vertically: canvas rows grow downward, camera y grows upward.
        sy = 1 - y/h
        for x in range(w):
            sx = x/w
            index = (x + y * w) * 4
            # rayTraceRecursive() already returns black on a miss, so no
            # separate intersection test is needed here.
            col = rayTraceRecursive(scene, camera.generateRay(sx, sy), maxReflect)
            imgdata.data[index + 0] = col.r * 255
            imgdata.data[index + 1] = col.g * 255
            imgdata.data[index + 2] = col.b * 255
            imgdata.data[index + 3] = 255

    # Draw at the canvas origin so the image covers the whole canvas.
    ctx.putImageData(imgdata,0,0)

# Clicking the element with id "toggle" flips CHECKER_PATTERN, which main()
# reads to choose the floor material, and then re-renders the scene.
toggle = document.getElementById('toggle')
toggle.onclick = def():
    # CHECKER_PATTERN is defined outside this handler; nonlocal makes the
    # assignment below rebind that outer variable instead of a new local.
    nonlocal CHECKER_PATTERN
    CHECKER_PATTERN = not CHECKER_PATTERN
    main()
# Initial render.
main()
