Skip to content

Commit 3a8c190

Browse files
committed
Extract rest of the functions/classes
- verified performance is unchanged
1 parent 4baf102 commit 3a8c190

File tree

9 files changed

+235
-227
lines changed

9 files changed

+235
-227
lines changed

src/RayTracingWeekend.jl

+18-226
Original file line numberDiff line numberDiff line change
@@ -6,236 +6,28 @@ using Random
66
using RandomNumbers.Xorshifts
77
using StaticArrays
88

9-
export color_vec3_in_rgb, default_camera, get_ray, hit, near_zero, point, random_between, random_vec2,
10-
random_vec2_in_disk, random_vec3, random_vec3_in_sphere, random_vec3_on_sphere, ray_color, ray_to_HitRecord, reflect,
11-
reflectance, refract, render, reseed!, rgb, rgb_gamma2, scatter, skycolor, squared_length, trand
12-
export Camera, Dielectric, Hittable, HittableList, HitRecord, Lambertian, Material, Metal, Ray, Scatter, Sphere, Vec3
13-
export scene_2_spheres, scene_4_spheres, scene_blue_red_spheres, scene_diel_spheres, scene_random_spheres
14-
export TRNG
15-
169
include("vec.jl")
17-
18-
struct Ray{T}
19-
origin::Vec3{T} # Point
20-
dir::Vec3{T} # Vec3 # direction (unit vector)
21-
end
22-
23-
@inline point(r::Ray{T}, t::T) where T <: AbstractFloat = r.origin .+ t .* r.dir # equivalent to C++'s ray.at()
24-
25-
@inline function skycolor(ray::Ray{T}) where T
26-
white = SA[1.0, 1.0, 1.0]
27-
skyblue = SA[0.5, 0.7, 1.0]
28-
t = T(0.5)*(ray.dir.y + one(T))
29-
(one(T)-t)*white + t*skyblue
30-
end
31-
32-
# Per-thread Random Number Generator. Initialized later...
33-
const TRNG = Xoroshiro128Plus[]
34-
35-
function __init__()
36-
# Instantiate 1 RNG (Random Number Generator) per thread, for performance.
37-
# This can't be done during precompilation since the number of threads isn't known then.
38-
resize!(TRNG, Threads.nthreads())
39-
for i in 1:Threads.nthreads()
40-
TRNG[i] = Xoroshiro128Plus(i)
41-
end
42-
nothing
43-
end
44-
10+
export Vec3, squared_length, near_zero, rgb, rgb_gamma2
11+
include("init.jl")
12+
export TRNG
13+
include("structs.jl")
14+
export Ray, Hittable, HittableList, Material, HitRecord, Sphere, Scatter
4515
include("rand.jl")
46-
47-
"An object that can be hit by Ray"
48-
abstract type Hittable end
49-
50-
"""Materials tell us how rays interact with a surface"""
51-
abstract type Material{T <: AbstractFloat} end
52-
53-
"Record a hit between a ray and an object's surface"
54-
struct HitRecord{T <: AbstractFloat}
55-
t::T # distance from the ray's origin to the intersection with a surface.
56-
57-
p::Vec3{T} # point of the intersection between an object's surface and a ray
58-
n⃗::Vec3{T} # surface's outward normal vector, points towards outside of object?
59-
60-
# If true, our ray hit from outside to the front of the surface.
61-
# If false, the ray hit from within.
62-
front_face::Bool
63-
64-
mat::Material{T}
65-
66-
@inline HitRecord(t::T,p,n⃗,front_face,mat) where T = new{T}(t,p,n⃗,front_face,mat)
67-
end
68-
69-
struct Sphere{T <: AbstractFloat} <: Hittable
70-
center::Vec3{T}
71-
radius::T
72-
mat::Material{T}
73-
end
74-
75-
"""Equivalent to `hit_record.set_face_normal()`"""
76-
@inline @fastmath function ray_to_HitRecord(t::T, p, outward_n⃗, r_dir::Vec3{T}, mat::Material{T})::Union{HitRecord,Nothing} where T
77-
front_face = r_dir outward_n⃗ < 0
78-
n⃗ = front_face ? outward_n⃗ : -outward_n⃗
79-
HitRecord(t,p,n⃗,front_face,mat)
80-
end
81-
82-
struct Scatter{T<: AbstractFloat}
83-
r::Ray{T}
84-
attenuation::Vec3{T}
85-
86-
# claforte: TODO: rename to "absorbed?", i.e. not reflected/refracted?
87-
reflected::Bool # whether the scattered ray was reflected, or fully absorbed
88-
@inline Scatter(r::Ray{T},a::Vec3{T},reflected=true) where T = new{T}(r,a,reflected)
89-
end
90-
91-
"""Compute reflection vector for v (pointing to surface) and normal n⃗.
92-
93-
See [diagram](https://raytracing.github.io/books/RayTracingInOneWeekend.html#metal/mirroredlightreflection)"""
94-
@inline @fastmath reflect(v::Vec3{T}, n⃗::Vec3{T}) where T = v - (2vn⃗)*n⃗
95-
96-
@inline @fastmath function hit(s::Sphere{T}, r::Ray{T}, tmin::T, tmax::T)::Union{HitRecord,Nothing} where T
97-
oc = r.origin - s.center
98-
#a = r.dir ⋅ r.dir # unnecessary since `r.dir` is normalized
99-
a = 1
100-
half_b = oc r.dir
101-
c = ococ - s.radius^2
102-
discriminant = half_b^2 - a*c
103-
if discriminant < 0 return nothing end # no hit!
104-
sqrtd = discriminant
105-
106-
# Find the nearest root that lies in the acceptable range
107-
root = (-half_b - sqrtd) / a
108-
if root < tmin || tmax < root
109-
root = (-half_b + sqrtd) / a
110-
if root < tmin || tmax < root
111-
return nothing # no hit!
112-
end
113-
end
114-
115-
t = root
116-
p = point(r, t)
117-
n⃗ = (p - s.center) / s.radius
118-
return ray_to_HitRecord(t, p, n⃗, r.dir, s.mat)
119-
end
120-
121-
const HittableList = Vector{Hittable}
122-
123-
"""Find closest hit between `Ray r` and a list of Hittable objects `h`, within distance `tmin` < `tmax`"""
124-
@inline function hit(hittables::HittableList, r::Ray{T}, tmin::T, tmax::T)::Union{HitRecord,Nothing} where T
125-
closest = tmax # closest t so far
126-
best_rec::Union{HitRecord,Nothing} = nothing # by default, no hit
127-
@inbounds for i in eachindex(hittables) # @paulmelis reported gave him a 4X speedup?!
128-
h = hittables[i]
129-
rec = hit(h, r, tmin, closest)
130-
if rec !== nothing
131-
best_rec = rec
132-
closest = best_rec.t # i.e. ignore any further hit > this one's.
133-
end
134-
end
135-
best_rec
136-
end
137-
138-
@inline color_vec3_in_rgb(v::Vec3{T}) where T = 0.5normalize(v) + SA{T}[0.5,0.5,0.5]
139-
140-
141-
"""Compute color for a ray, recursively
142-
143-
Args:
144-
depth: how many more levels of recursive ray bounces can we still compute?"""
145-
@inline @fastmath function ray_color(r::Ray{T}, world::HittableList, depth=16) where T
146-
if depth <= 0
147-
return SA{T}[0,0,0]
148-
end
149-
150-
rec::Union{HitRecord,Nothing} = hit(world, r, T(1e-4), typemax(T))
151-
if rec !== nothing
152-
# For debugging, represent vectors as RGB:
153-
# claforte TODO: adapt to latest code!
154-
# return color_vec3_in_rgb(rec.p) # show the normalized hit point
155-
# return color_vec3_in_rgb(rec.n⃗) # show the normal in RGB
156-
# return color_vec3_in_rgb(rec.p + rec.n⃗)
157-
# return color_vec3_in_rgb(random_vec3_in_sphere())
158-
#return color_vec3_in_rgb(rec.n⃗ + random_vec3_in_sphere())
159-
160-
s = scatter(rec.mat, r, rec)
161-
if s.reflected
162-
return s.attenuation .* ray_color(s.r, world, depth-1)
163-
else
164-
return SA{T}[0,0,0]
165-
end
166-
else
167-
skycolor(r)
168-
end
169-
end
170-
16+
export reseed!, trand, random_vec3_in_sphere, random_between, random_vec3, random_vec2,
17+
random_vec3_on_sphere, random_vec2_in_disk
18+
include("hit.jl")
19+
export point, ray_to_HitRecord, hit
20+
include("ray_color.jl")
21+
export skycolor, color_vec3_in_rgb, ray_color
17122
include("camera.jl")
172-
173-
"""Render an image of `scene` using the specified camera, number of samples.
174-
175-
Args:
176-
scene: a HittableList, e.g. a list of spheres
177-
n_samples: number of samples per pixel, eq. to C++ samples_per_pixel
178-
179-
Equivalent to C++'s `main` function."""
180-
function render(scene::HittableList, cam::Camera{T}, image_width=400,
181-
n_samples=1) where T
182-
# Image
183-
aspect_ratio = T(16.0/9.0) # TODO: use cam.aspect_ratio for consistency
184-
image_height = convert(Int64, floor(image_width / aspect_ratio))
185-
186-
# Render
187-
img = zeros(RGB{T}, image_height, image_width)
188-
f32_image_width = convert(Float32, image_width)
189-
f32_image_height = convert(Float32, image_height)
190-
191-
# Reset the random seeds, so we always get the same images...
192-
# Makes comparing performance more accurate.
193-
reseed!()
194-
195-
Threads.@threads for i in 1:image_height
196-
@inbounds for j in 1:image_width # iterate over each row (FASTER?!)
197-
accum_color = SA{T}[0,0,0]
198-
u = convert(T, j/image_width)
199-
v = convert(T, (image_height-i)/image_height) # i is Y-down, v is Y-up!
200-
201-
for s in 1:n_samples
202-
if s == 1 # 1st sample is always centered
203-
δu = δv = T(0)
204-
else
205-
# Supersampling antialiasing.
206-
δu = trand(T) / f32_image_width
207-
δv = trand(T) / f32_image_height
208-
end
209-
ray = get_ray(cam, u+δu, v+δv)
210-
accum_color += ray_color(ray, scene)
211-
end
212-
img[i,j] = rgb_gamma2(accum_color / n_samples)
213-
end
214-
end
215-
img
216-
end
217-
218-
"""
219-
Args:
220-
refraction_ratio: incident refraction index divided by refraction index of
221-
hit surface. i.e. η/η′ in the figure above"""
222-
@inline @fastmath function refract(dir::Vec3{T}, n⃗::Vec3{T}, refraction_ratio::T) where T
223-
cosθ = min(-dir n⃗, one(T))
224-
r_out_perp = refraction_ratio * (dir + cosθ*n⃗)
225-
r_out_parallel = -√(abs(one(T)-squared_length(r_out_perp))) * n⃗
226-
normalize(r_out_perp + r_out_parallel)
227-
end
228-
229-
@inline @fastmath function reflectance(cosθ, refraction_ratio)
230-
# Use Schlick's approximation for reflectance.
231-
# claforte: may be buggy? I'm getting black pixels in the Hollow Glass Sphere...
232-
r0 = (1-refraction_ratio) / (1+refraction_ratio)
233-
r0 = r0^2
234-
r0 + (1-r0)*((1-cosθ)^5)
235-
end
236-
23+
export Camera, default_camera, get_ray
24+
include("render.jl")
25+
export render
26+
include("light.jl") # light transport
27+
export reflect, refract, reflectance
23728
include("material.jl")
238-
29+
export Lambertian, scatter, Metal, Dielectric
23930
include("scenes.jl")
31+
export scene_2_spheres, scene_4_spheres, scene_blue_red_spheres, scene_diel_spheres, scene_random_spheres
24032

24133
end

src/hit.jl

+51
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# Calculate intersections with surface(s)
2+
3+
@inline point(r::Ray{T}, t::T) where T <: AbstractFloat = r.origin .+ t .* r.dir # equivalent to C++'s ray.at()
4+
5+
"""Equivalent to `hit_record.set_face_normal()`"""
6+
@inline @fastmath function ray_to_HitRecord(t::T, p, outward_n⃗, r_dir::Vec3{T}, mat::Material{T})::Union{HitRecord,Nothing} where T
7+
front_face = r_dir outward_n⃗ < 0
8+
n⃗ = front_face ? outward_n⃗ : -outward_n⃗
9+
HitRecord(t,p,n⃗,front_face,mat)
10+
end
11+
12+
@inline @fastmath function hit(s::Sphere{T}, r::Ray{T}, tmin::T, tmax::T)::Union{HitRecord,Nothing} where T
13+
oc = r.origin - s.center
14+
#a = r.dir ⋅ r.dir # unnecessary since `r.dir` is normalized
15+
a = 1
16+
half_b = oc r.dir
17+
c = ococ - s.radius^2
18+
discriminant = half_b^2 - a*c
19+
if discriminant < 0 return nothing end # no hit!
20+
sqrtd = discriminant
21+
22+
# Find the nearest root that lies in the acceptable range
23+
root = (-half_b - sqrtd) / a
24+
if root < tmin || tmax < root
25+
root = (-half_b + sqrtd) / a
26+
if root < tmin || tmax < root
27+
return nothing # no hit!
28+
end
29+
end
30+
31+
t = root
32+
p = point(r, t)
33+
n⃗ = (p - s.center) / s.radius
34+
return ray_to_HitRecord(t, p, n⃗, r.dir, s.mat)
35+
end
36+
37+
"""Find closest hit between `Ray r` and a list of Hittable objects `h`, within distance `tmin` < `tmax`"""
38+
@inline function hit(hittables::HittableList, r::Ray{T}, tmin::T, tmax::T)::Union{HitRecord,Nothing} where T
39+
closest = tmax # closest t so far
40+
best_rec::Union{HitRecord,Nothing} = nothing # by default, no hit
41+
@inbounds for i in eachindex(hittables) # @paulmelis reported gave him a 4X speedup?!
42+
h = hittables[i]
43+
rec = hit(h, r, tmin, closest)
44+
if rec !== nothing
45+
best_rec = rec
46+
closest = best_rec.t # i.e. ignore any further hit > this one's.
47+
end
48+
end
49+
best_rec
50+
end
51+

src/init.jl

+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# Per-thread Random Number Generator. Initialized later...
2+
const TRNG = Xoroshiro128Plus[]
3+
4+
function __init__()
5+
# Instantiate 1 RNG (Random Number Generator) per thread, for performance.
6+
# This can't be done during precompilation since the number of threads isn't known then.
7+
resize!(TRNG, Threads.nthreads())
8+
for i in 1:Threads.nthreads()
9+
TRNG[i] = Xoroshiro128Plus(i)
10+
end
11+
nothing
12+
end

src/light.jl

+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# light transport, e.g. reflection, refraction
2+
3+
"""Compute reflection vector for v (pointing to surface) and normal n⃗.
4+
5+
See [diagram](https://raytracing.github.io/books/RayTracingInOneWeekend.html#metal/mirroredlightreflection)"""
6+
@inline @fastmath reflect(v::Vec3{T}, n⃗::Vec3{T}) where T = v - (2vn⃗)*n⃗
7+
8+
"""
9+
Args:
10+
refraction_ratio: incident refraction index divided by refraction index of
11+
hit surface. i.e. η/η′ in the figure above"""
12+
@inline @fastmath function refract(dir::Vec3{T}, n⃗::Vec3{T}, refraction_ratio::T) where T
13+
cosθ = min(-dir n⃗, one(T))
14+
r_out_perp = refraction_ratio * (dir + cosθ*n⃗)
15+
r_out_parallel = -√(abs(one(T)-squared_length(r_out_perp))) * n⃗
16+
normalize(r_out_perp + r_out_parallel)
17+
end
18+
19+
@inline @fastmath function reflectance(cosθ, refraction_ratio)
20+
# Use Schlick's approximation for reflectance.
21+
# claforte: may be buggy? I'm getting black pixels in the Hollow Glass Sphere...
22+
r0 = (1-refraction_ratio) / (1+refraction_ratio)
23+
r0 = r0^2
24+
r0 + (1-r0)*((1-cosθ)^5)
25+
end

src/material.jl

+2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
# Concrete materials
2+
13
struct Lambertian{T} <: Material{T}
24
albedo::Vec3{T}
35
end

src/proto/proto.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ _scene_random_spheres = scene_random_spheres(; elem_type=ELEM_TYPE)
268268
# 6.766 ms (161000 allocations: 12.40 MiB)
269269
# @inbounds and @simd in low-level functions
270270
# 6.519 ms (160609 allocations: 12.37 MiB)
271-
#render(scene_diel_spheres(; elem_type=ELEM_TYPE), t_cam2, 96, 16)
271+
render(scene_diel_spheres(; elem_type=ELEM_TYPE), t_cam2, 96, 16)
272272

273273
using Profile
274274
render(scene_random_spheres(; elem_type=ELEM_TYPE), t_cam1, 16, 1)

src/ray_color.jl

+38
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
@inline function skycolor(ray::Ray{T}) where T
2+
white = SA[1.0, 1.0, 1.0]
3+
skyblue = SA[0.5, 0.7, 1.0]
4+
t = T(0.5)*(ray.dir.y + one(T))
5+
(one(T)-t)*white + t*skyblue
6+
end
7+
8+
@inline color_vec3_in_rgb(v::Vec3{T}) where T = 0.5normalize(v) + SA{T}[0.5,0.5,0.5]
9+
10+
"""Compute color for a ray, recursively
11+
12+
Args:
13+
depth: how many more levels of recursive ray bounces can we still compute?"""
14+
@inline @fastmath function ray_color(r::Ray{T}, world::HittableList, depth=16) where T
15+
if depth <= 0
16+
return SA{T}[0,0,0]
17+
end
18+
19+
rec::Union{HitRecord,Nothing} = hit(world, r, T(1e-4), typemax(T))
20+
if rec !== nothing
21+
# For debugging, represent vectors as RGB:
22+
# claforte TODO: adapt to latest code!
23+
# return color_vec3_in_rgb(rec.p) # show the normalized hit point
24+
# return color_vec3_in_rgb(rec.n⃗) # show the normal in RGB
25+
# return color_vec3_in_rgb(rec.p + rec.n⃗)
26+
# return color_vec3_in_rgb(random_vec3_in_sphere())
27+
#return color_vec3_in_rgb(rec.n⃗ + random_vec3_in_sphere())
28+
29+
s = scatter(rec.mat, r, rec)
30+
if s.reflected
31+
return s.attenuation .* ray_color(s.r, world, depth-1)
32+
else
33+
return SA{T}[0,0,0]
34+
end
35+
else
36+
skycolor(r)
37+
end
38+
end

0 commit comments

Comments
 (0)