import numpy as np
from abc import ABC, abstractmethod


class Ray:
    """A half-line defined by an origin point and a unit-length direction."""

    def __init__(self, origin, direction):
        """Create a ray.

        Args:
            origin: Start point of the ray (array-like of coordinates).
            direction: Direction vector (array-like); normalized to unit length.

        Raises:
            ValueError: If ``direction`` has zero length (it cannot be normalized).
        """
        self.origin = np.array(origin)
        direction = np.array(direction)
        length = np.linalg.norm(direction)
        if length == 0:
            raise ValueError("Ray direction must be a non-zero vector")
        self.direction = direction / length


class Shapes(ABC):
    """Abstract base class for traceable shapes that carry a refractive index."""

    def __init__(self, refractive_index) -> None:
        self.refractive_index = refractive_index

    @abstractmethod
    def find_intersection(self, ray: Ray):
        """Return ``(t, point)`` for the nearest hit along ``ray``.

        Implementations return ``(None, None)`` when the ray misses the shape.
        """

    def get_normal(self, point):
        """Return the surface normal at ``point``.

        Subclasses are expected to override this; the base implementation
        returns ``None``.
        """
        return None


class Scheme:
    """A scene of shapes and rays; all rays are traced on construction."""

    def __init__(self, shapes, rays) -> None:
        """Store the scene and immediately trace every ray.

        Args:
            shapes: Iterable of ``Shapes`` instances.
            rays: Iterable of ``Ray`` instances.
        """
        self.shapes = shapes
        self.rays = rays
        self.raytrace()

    def _closest_intersection(self, ray: Ray):
        """Return ``(t, point, shape)`` for the nearest shape hit by ``ray``.

        Returns ``(None, None, None)`` when no shape is hit. On equal ``t`` the
        shape that appears first in ``self.shapes`` wins.
        """
        best_t, best_point, best_shape = None, None, None
        for shape in self.shapes:
            t, point = shape.find_intersection(ray)
            if t is None:
                continue
            if best_t is None or t < best_t:
                best_t, best_point, best_shape = t, point, shape
        return best_t, best_point, best_shape

    def raytrace(self):
        """Trace each ray to its nearest hit and pass it to ``reflection_and_refraction``.

        Rays that hit nothing are skipped.
        """
        for ray in self.rays:
            t, point, shape = self._closest_intersection(ray)
            if t is not None:
                self.reflection_and_refraction(ray, shape, point)

    def reflection_and_refraction(self, ray: Ray, shape: Shapes, intersection_point):
        """Handle ``ray`` hitting ``shape`` at ``intersection_point``.

        Placeholder hook: currently does nothing. It takes ``self`` so that the
        call made from ``raytrace`` binds correctly.
        """


class Disk(Shapes):
    """A circle (or N-dimensional sphere) given by ``center`` and ``radius``."""

    def __init__(self, center, radius, refractive_index):
        super().__init__(refractive_index)
        self.center = np.array(center)
        self.radius = radius

    def find_intersection(self, ray: Ray):
        """Intersect ``ray`` with the circle boundary.

        Solves ``|origin + t*direction - center|^2 = radius^2`` for ``t``.

        Returns:
            ``(t, point)`` with ``t >= 0`` and ``point`` the actual position on
            the boundary, or ``(None, None)`` if there is no forward hit.
        """
        oc = ray.origin - self.center
        a = np.dot(ray.direction, ray.direction)
        b = 2.0 * np.dot(oc, ray.direction)
        c = np.dot(oc, oc) - self.radius * self.radius
        discriminant = b * b - 4 * a * c
        if discriminant < 0:
            return None, None  # no intersection
        sqrt_disc = np.sqrt(discriminant)
        if c <= 1e-10:
            # Origin inside (or on) the circle: take the far root so a ray
            # starting on the boundary does not re-hit its own start point.
            t = (-b + sqrt_disc) / (2.0 * a)
        else:
            t = (-b - sqrt_disc) / (2.0 * a)
        if t < 0:
            return None, None  # circle lies behind the ray origin
        # The hit is a position, not a direction: it must not be normalized.
        return t, ray.origin + t * ray.direction

    def get_normal(self, point):
        """Return the unit vector pointing from the center towards ``point``."""
        normal = (point - self.center) / self.radius
        return normal / np.linalg.norm(normal)


class half_plane(Shapes):
    """Boundary line/plane through ``point`` with unit ``normal``.

    NOTE(review): only the boundary is intersected; the class name suggests
    the region on one side of it — confirm the intended semantics.
    """

    def __init__(self, point, normal, refractive_index) -> None:
        """Create the boundary.

        Raises:
            ValueError: If ``normal`` has zero length.
        """
        super().__init__(refractive_index)
        self.point = np.array(point)
        normal = np.array(normal)
        length = np.linalg.norm(normal)
        if length == 0:
            raise ValueError("half_plane normal must be a non-zero vector")
        self.normal = normal / length

    def find_intersection(self, ray: Ray):
        """Return ``(t, point)`` where ``ray`` crosses the boundary, else ``(None, None)``.

        Misses are reported for rays (nearly) parallel to the boundary and for
        crossings behind the ray origin (``t < 0``), matching ``Disk``.
        """
        denom = np.dot(ray.direction, self.normal)
        if abs(denom) < 1e-12:
            return None, None  # ray parallel to the boundary
        t = np.dot(self.point - ray.origin, self.normal) / denom
        if t < 0:
            return None, None  # boundary lies behind the ray origin
        return t, ray.origin + t * ray.direction

    def get_normal(self, point):
        """Return the (constant) unit normal of the boundary."""
        return self.normal