rainbow/raytrace_2D.py

87 lines
2.7 KiB
Python

import numpy as np
from abc import ABC, abstractmethod
class Ray:
    """A ray described by an origin point and a unit-length direction.

    Attributes:
        origin: Start point of the ray, stored as a float NumPy array.
        direction: Direction of travel, normalized to unit length.
    """

    def __init__(self, origin, direction):
        """Store the origin and the normalized direction.

        Args:
            origin: Array-like start point of the ray.
            direction: Array-like direction vector; it does not need to be
                normalized, but it must have a finite, non-zero length.

        Raises:
            ValueError: If ``direction`` has zero or non-finite length, since
                normalizing it would produce NaN/inf components.
        """
        # np.array (not asarray) so the ray never aliases the caller's data.
        self.origin = np.array(origin, dtype=float)
        direction = np.array(direction, dtype=float)
        length = np.linalg.norm(direction)
        if not np.isfinite(length) or length == 0.0:
            raise ValueError("Ray direction must be a finite, non-zero vector")
        self.direction = direction / length
class Shapes(ABC):
    """Abstract base for shapes that carry a refractive index.

    Subclasses must implement ``find_intersection``. ``get_normal`` is an
    overridable hook; the default implementation here returns ``None``.
    """

    def __init__(self, refractive_index) -> None:
        """Remember the refractive index associated with this shape."""
        self.refractive_index = refractive_index

    @abstractmethod
    def find_intersection(self, ray: Ray):
        """Locate where ``ray`` meets this shape.

        ``Scheme.raytrace`` unpacks the result as ``(t, intersection_point)``
        and treats ``t is None`` as "no hit".
        """

    def get_normal(self, point):
        """Return the surface normal at ``point`` (``None`` by default)."""
        return None
class Scheme:
    """A set of shapes and rays; tracing is run once on construction.

    Attributes:
        shapes: Iterable of objects exposing ``find_intersection(ray)``.
        rays: Iterable of rays to trace against ``shapes``.
    """

    def __init__(self, shapes, rays) -> None:
        """Store the shapes and rays, then immediately call ``raytrace``."""
        self.shapes = shapes
        self.rays = rays
        self.raytrace()

    def raytrace(self):
        """Find the nearest hit of every ray and hand it on.

        For each ray, every shape is asked for ``(t, intersection_point)``.
        Shapes reporting ``t is None`` are ignored; among the rest the one
        with the smallest ``t`` wins (the first shape wins a tie). Rays that
        hit nothing are skipped; otherwise ``reflection_and_refraction`` is
        called with the ray, the winning shape and its intersection point.
        """
        for ray in self.rays:
            nearest_t = None
            nearest_point = None
            nearest_shape = None
            for shape in self.shapes:
                t, point = shape.find_intersection(ray)
                if t is None:
                    # This shape is missed by the ray.
                    continue
                if nearest_t is None or t < nearest_t:
                    nearest_t, nearest_point, nearest_shape = t, point, shape
            if nearest_t is not None:
                self.reflection_and_refraction(ray, nearest_shape, nearest_point)

    def reflection_and_refraction(self, ray: "Ray", shape: "Shapes", intersection_point):
        """Handle a ray hitting ``shape`` at ``intersection_point``.

        Currently a stub that does nothing and returns ``None``. It takes
        ``self`` because ``raytrace`` invokes it as a bound method; without
        it the call raised ``TypeError`` for every ray that hit a shape.
        """
        pass
class Disk(Shapes):
    """Circular disk defined by a center, a radius and a refractive index."""

    def __init__(self, center, radius, refractive_index):
        """Store the geometry; ``center`` is copied into a NumPy array."""
        super().__init__(refractive_index)
        self.center = np.array(center)
        self.radius = radius

    def find_intersection(self, ray: Ray):
        """Intersect ``ray`` with the disk's boundary circle.

        The ray origin is classified as on, inside or outside the circle by
        comparing its squared distance from the center with ``radius**2``
        (relative tolerance 1e-12):

        * on the circle: heading inward gives the far root (the exit point);
          heading outward or tangentially gives no hit;
        * inside: the far root;
        * outside: the near root, rejected if it lies behind the origin.

        Args:
            ray: Object with ``origin`` and ``direction`` arrays.

        Returns:
            ``(t, point)`` where ``point = origin + t * direction`` snapped
            back onto the circle, or ``(None, None)`` when there is no hit.
        """
        oc = ray.origin - self.center
        radius_sq = self.radius * self.radius
        dist_sq = np.dot(oc, oc)
        a = np.dot(ray.direction, ray.direction)
        b = 2.0 * np.dot(oc, ray.direction)
        discriminant = b * b - 4 * a * (dist_sq - radius_sq)
        if discriminant < 0:
            return None, None  # no intersection

        root = np.sqrt(discriminant)
        # Tolerance is scaled by radius**2 so the on-boundary test works for
        # any radius (for radius 1 it equals the former hard-coded bounds).
        epsilon = 1e-12
        lower = radius_sq * (1 - epsilon)
        upper = radius_sq * (1 + epsilon)
        if lower < dist_sq < upper:
            # Origin sits on the circle: only an inward-pointing ray crosses
            # the disk, and it leaves through the far root.
            if np.dot(self.get_normal(ray.origin), ray.direction) >= 0:
                return None, None
            t = (-b + root) / (2.0 * a)
        elif dist_sq <= lower:
            # Origin strictly inside: the ray exits through the far root.
            t = (-b + root) / (2.0 * a)
        else:
            # Origin outside: the ray enters through the near root.
            t = (-b - root) / (2.0 * a)
        if t < 0:
            return None, None

        hit = ray.origin + t * ray.direction
        # Re-project onto the circle so rounding error does not push the
        # point off the boundary (a later ray may start from this point).
        offset = hit - self.center
        hit = self.center + self.radius * offset / np.linalg.norm(offset)
        return t, hit

    def get_normal(self, point):
        """Return the unit vector pointing from the center towards ``point``."""
        offset = point - self.center
        return offset / np.linalg.norm(offset)
class half_plane(Shapes):
    """Shape bounded by the line through ``point`` with normal ``normal``."""

    def __init__(self, point, normal, refractive_index) -> None:
        """Store a point on the boundary and the unit normal.

        Args:
            point: Array-like point lying on the boundary.
            normal: Array-like, non-zero normal; it is normalized here.
            refractive_index: Passed on to ``Shapes.__init__``.
        """
        super().__init__(refractive_index)
        # Coerce to arrays so list/tuple inputs behave like Disk.center.
        self.point = np.array(point, dtype=float)
        normal = np.array(normal, dtype=float)
        self.normal = normal / np.linalg.norm(normal)

    def find_intersection(self, ray: Ray):
        """Intersect ``ray`` with the boundary.

        Returns:
            ``(t, point)`` with ``point = ray.origin + t * ray.direction``,
            or ``(None, None)`` when the ray is parallel to the boundary or
            the crossing lies behind the ray origin (``t < 0``).
        """
        dot = np.dot(ray.direction, self.normal)
        if dot == 0:
            # Ray runs parallel to the boundary.
            return None, None
        t = np.dot(self.point - ray.origin, self.normal) / dot
        if t < 0:
            # A crossing behind the origin is not on the ray; Disk rejects
            # negative t the same way.
            return None, None
        # NOTE(review): t == 0 (origin already on the boundary) is still
        # reported as a hit - confirm whether callers want that.
        return t, ray.origin + t * ray.direction

    def get_normal(self, point):
        """Return the unit normal, which is the same at every ``point``."""
        return self.normal