Skip to content

Commit 5fb0267

Browse files
committed
unsuccesful attempt to speed up using Float32
- somehow there are more allocations - visually the results appear identical though
1 parent 87f4f54 commit 5fb0267

File tree

1 file changed

+73
-56
lines changed

1 file changed

+73
-56
lines changed

src/proto.jl

+73-56
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
using Pkg
55
Pkg.activate(@__DIR__)
66

7-
MyFloat = Float64 # Float32
7+
MyFloat = Float32 # Float64
88

99

1010
using BenchmarkTools, Images, InteractiveUtils, LinearAlgebra, StaticArrays
@@ -13,7 +13,7 @@ Vec3 = SVector{3, MyFloat}
1313
Vec2 = SVector{2, MyFloat}
1414
Point = Vec3
1515
Color = Vec3
16-
t_col = Color(0.4, 0.5, 0.1) # test color
16+
t_col = Color(0.4f0, 0.5f0, 0.1f0) # test color
1717

1818
# claforte: This was meant to be a convenient function to get some_vec.x or some_color.r,
1919
# but this causes ~41 allocations per call, so this become a huge bottleneck.
@@ -91,10 +91,10 @@ end
9191

9292
gradient(200,100)
9393

94-
rgb(v::Vec3) = RGB(v...)
94+
rgb(v::Vec3) = RGB{MyFloat}(v...)
9595
rgb(t_col)
9696

97-
rgb_gamma2(v::Vec3) = RGB(sqrt.(v)...)
97+
rgb_gamma2(v::Vec3) = RGB{MyFloat}(sqrt.(v)...)
9898

9999
rgb_gamma2(t_col)
100100

@@ -111,24 +111,30 @@ end
111111
# r.origin .+ t .* r.dir
112112
# end
113113

114-
function point(r::Ray, t::Float64)::Point # point at parameter t
114+
function point(r::Ray, t::MyFloat)::Point # point at parameter t
115115
r.origin .+ t .* r.dir
116116
end
117117

118-
119118
#md"# Chapter 4: Rays, simple camera, and background"
120119

121-
function sky_color(ray::Ray)
122-
# NOTE: unlike in the C++ implementation, we normalize the ray direction.
123-
t = 0.5(ray.dir[2] + 1.0)
120+
typeof(MyFloat(1.0))
121+
typeof(1.0f0)
122+
123+
function sky_color(ray::Ray)::Color
124+
# NOTE: unlike in the C++ implementation, the ray direction is normalized beforehand.
125+
t = 0.5f0(ray.dir[2] + 1.0f0)
124126
#t = 0.5(ray.dir.y + 1.0)
125-
(1-t)*Color(1,1,1) + t*Color(0.5, 0.7, 1.0)
127+
res = (1.0f0-t)*Color(1.0f0,1.0f0,1.0f0) + t*Color(0.5f0, 0.7f0, 1.0f0)
128+
res
126129
end
127130

128131
# interpolates between blue and white
129-
rgb(Color(0.5, 0.7, 1.0)), rgb(Color(1.0, 1.0, 1.0))
132+
p_zero = Point(0.0f0,0.0f0,0.0f0)
133+
v3_minusY = Vec3(0.0f0,-1.0f0,0.0f0)
130134

131-
rgb(sky_color(Ray(Point(0,0,0), Vec3(0,-1,0))))
135+
136+
rgb(Color(0.5, 0.7, 1.0)), rgb(Color(1.0, 1.0, 1.0))
137+
rgb(sky_color(Ray(Point(0.0,0.0,0.0), Vec3(0.0,-1.0,0.0))))
132138

133139
# before optimization: @btime rgb(sky_color(Ray(Point(0,0,0), Vec3(0,-1,0))))
134140
# 1.164 μs (19 allocations: 560 bytes)
@@ -138,21 +144,28 @@ rgb(sky_color(Ray(Point(0,0,0), Vec3(0,-1,0))))
138144
# 509.135 ns (7 allocations: 224 bytes)
139145
# test after forcing Float64 instead of AbstractFloat in Ray()...
140146
# 308.289 ns (4 allocations: 128 bytes)
141-
@btime Ray(Point(0,0,0), Vec3(0,-1,0))
147+
# test after forcing every Float64 to Float32:
148+
# 333.710 ns (4 allocations: 128 bytes)
149+
# test after forcing every input to Float32: `@btime Ray(Point(0.0f0,0.0f0,0.0f0), Vec3(0.0f0,-1.0f0,0.0f0))`
150+
# 294.919 ns (4 allocations: 128 bytes)
151+
# test allocate outside, i.e.: `@btime Ray(p_zero, v3_minusY)`
152+
# 9.858 ns (0 allocations: 0 bytes)
153+
#
154+
@btime Ray(p_zero, v3_minusY)
142155

143156
# md"""# Random vectors
144157
# C++'s section 8.1"""
145158

146-
random_between(min=0.0, max=1.0) = rand()*(max-min) + min # equiv to random_double()
147-
#random_between(50, 100)
159+
random_between(min=0.0f0, max=1.0f0) = rand()*(max-min) + min # equiv to random_double()
160+
@btime random_between(50, 100) # 4.399 ns (0 allocations: 0 bytes)
148161

149162
#[random_between(50.0, 100.0) for i in 1:3]
150163

151-
function random_vec3(min=0.0, max=1.0)
164+
function random_vec3(min=0.0f0, max=1.0f0)
152165
Vec3([random_between(min, max) for i in 1:3]...)
153166
end
154167

155-
random_vec3(-1,1)
168+
@btime random_vec3(-1.0f0,1.0f0)
156169

157170
function random_vec3_in_sphere() # equiv to random_in_unit_sphere()
158171
while (true)
@@ -167,7 +180,9 @@ squared_length(random_vec3_in_sphere())
167180

168181
"Random unit vector. Equivalent to C++'s `unit_vector(random_in_unit_sphere())`"
169182
random_vec3_on_sphere() = normalize(random_vec3_in_sphere())
170-
random_vec3_on_sphere()
183+
184+
# TO OPTIMIZE!
185+
@btime random_vec3_on_sphere() # 517.538 ns (12 allocations: 418 bytes)... but random
171186
norm(random_vec3_on_sphere())
172187

173188
function random_vec2_in_disk() :: Vec2 # equiv to random_in_unit_disk()
@@ -191,7 +206,7 @@ function main(nx::Int, ny::Int, scene)
191206
vertical = Vec3(0, 2, 0)
192207
origin = Point(0, 0, 0)
193208

194-
img = zeros(RGB, ny, nx)
209+
img = zeros(RGB{MyFloat}, ny, nx)
195210
for i in 1:ny, j in 1:nx # Julia is column-major, i.e. iterate 1 column at a time
196211
u = j/nx
197212
v = (ny-i)/ny # Y-up!
@@ -204,7 +219,7 @@ function main(nx::Int, ny::Int, scene)
204219
img
205220
end
206221

207-
main(200,100, sky_color)
222+
main(200,100, sky_color) # TO OPTIMIZE! 28.630 ms (520010 allocations: 13.05 MiB)
208223

209224
#md"# Chapter 5: Add a sphere"
210225

@@ -243,11 +258,11 @@ function hit_sphere2(center::Vec3, radius::AbstractFloat, r::Ray)
243258
end
244259

245260
function sphere_scene2(r::Ray)
246-
sphere_center = Vec3(0,0,-1)
247-
t = hit_sphere2(sphere_center, 0.5, r) # sphere of radius 0.5 centered at z=-1
248-
if t > 0
261+
sphere_center = Vec3(0f0,0f0,-1f0)
262+
t = hit_sphere2(sphere_center, 0.5f0, r) # sphere of radius 0.5 centered at z=-1
263+
if t > 0f0
249264
n⃗ = normalize(point(r, t) - sphere_center) # normal vector. typed n\vec
250-
return 0.5n⃗ + Vec3(0.5,0.5,0.5) # remap normal to rgb
265+
return 0.5f0n⃗ + Vec3(0.5f0,0.5f0,0.5f0) # remap normal to rgb
251266
else
252267
sky_color(r)
253268
end
@@ -265,7 +280,7 @@ abstract type Material end
265280
mutable struct HitRecord
266281
# claforte: Not sure if this needs to be mutable... might impact performance!
267282

268-
t::Float64 # vector from the ray's origin to the intersection with a surface
283+
t::MyFloat # vector from the ray's origin to the intersection with a surface
269284
p::Vec3 # point of the intersection between an object's surface and a ray
270285
n⃗::Vec3 # surface's outward normal vector, points towards outside of object?
271286

@@ -277,7 +292,7 @@ end
277292

278293
struct Sphere <: Hittable
279294
center::Vec3
280-
radius::Float64
295+
radius::MyFloat
281296
mat::Material
282297
end
283298

@@ -331,7 +346,7 @@ function scatter(mat::Lambertian, r::Ray, rec::HitRecord)::Scatter
331346
return Scatter(scattered_r, attenuation)
332347
end
333348

334-
function hit(s::Sphere, r::Ray, tmin::Float64, tmax::Float64)::Option{HitRecord}
349+
function hit(s::Sphere, r::Ray, tmin::MyFloat, tmax::MyFloat)::Option{HitRecord}
335350
oc = r.origin - s.center
336351
a = 1 #r.dir ⋅ r.dir # normalized vector - always 1
337352
half_b = oc r.dir
@@ -360,8 +375,8 @@ struct HittableList <: Hittable
360375
end
361376

362377
#"""Find closest hit between `Ray r` and a list of Hittable objects `h`, within distance `tmin` < `tmax`"""
363-
function hit(hittables::HittableList, r::Ray, tmin::Float64,
364-
tmax::Float64)::Option{HitRecord}
378+
function hit(hittables::HittableList, r::Ray, tmin::MyFloat,
379+
tmax::MyFloat)::Option{HitRecord}
365380
closest = tmax # closest t so far
366381
rec = missing
367382
for h in hittables.list
@@ -380,7 +395,7 @@ color_vec3_in_rgb(v::Vec3) = 0.5normalize(v) + Vec3(0.5,0.5,0.5)
380395

381396
mutable struct Metal<:Material
382397
albedo::Color
383-
fuzz::Float64 # how big the sphere used to generate fuzzy reflection rays. 0=none
398+
fuzz::MyFloat # how big the sphere used to generate fuzzy reflection rays. 0=none
384399
Metal(a,f=0.0) = new(a,f)
385400
end
386401

@@ -427,7 +442,7 @@ mutable struct Camera
427442
u::Vec3
428443
v::Vec3
429444
w::Vec3
430-
lens_radius::Float64
445+
lens_radius::MyFloat
431446
end
432447

433448
"""
@@ -436,10 +451,10 @@ end
436451
aspect_ratio: horizontal/vertical ratio of pixels
437452
aperture: if 0 - no depth-of-field
438453
"""
439-
function default_camera(lookfrom::Point=Point(0,0,0), lookat::Point=Point(0,0,-1),
440-
vup::Vec3=Vec3(0,1,0), vfov=90.0, aspect_ratio=16.0/9.0,
441-
aperture=0.0, focus_dist=1.0)
442-
viewport_height = 2.0 * tand(vfov/2)
454+
function default_camera(lookfrom::Point=Point(0f0,0f0,0f0), lookat::Point=Point(0f0,0f0,-1f0),
455+
vup::Vec3=Vec3(0f0,1f0,0f0), vfov=90.0f0, aspect_ratio=16.0f0/9.0f0,
456+
aperture=0.0f0, focus_dist=1.0f0)
457+
viewport_height = 2.0f0 * tand(vfov/2f0)
443458
viewport_width = aspect_ratio * viewport_height
444459

445460
w = normalize(lookfrom - lookat)
@@ -449,8 +464,8 @@ function default_camera(lookfrom::Point=Point(0,0,0), lookat::Point=Point(0,0,-1
449464
origin = lookfrom
450465
horizontal = focus_dist * viewport_width * u
451466
vertical = focus_dist * viewport_height * v
452-
lower_left_corner = origin - horizontal/2 - vertical/2 - focus_dist*w
453-
lens_radius = aperture/2
467+
lower_left_corner = origin - horizontal/2f0 - vertical/2f0 - focus_dist*w
468+
lens_radius = aperture/2f0
454469
Camera(origin, lower_left_corner, horizontal, vertical, u, v, w, lens_radius)
455470
end
456471

@@ -460,14 +475,14 @@ clamp(3.5, 0, 1)
460475

461476
#md"# Render
462477

463-
function get_ray(c::Camera, s::Float64, t::Float64)
478+
function get_ray(c::Camera, s::MyFloat, t::MyFloat)
464479
rd = Vec2(c.lens_radius * random_vec2_in_disk())
465480
offset = c.u * rd[1] + c.v * rd[2] #offset = c.u * rd.x + c.v * rd.y
466481
Ray(c.origin + offset, normalize(c.lower_left_corner + s*c.horizontal +
467482
t*c.vertical - c.origin - offset))
468483
end
469484

470-
get_ray(default_camera(), 0.0, 0.0)
485+
get_ray(default_camera(), 0.0f0, 0.0f0)
471486

472487
"""Compute color for a ray, recursively
473488
@@ -478,7 +493,7 @@ function ray_color(r::Ray, world::HittableList, depth=4)::Vec3
478493
return Vec3(0,0,0)
479494
end
480495

481-
rec = hit(world, r, 1e-4, Inf)
496+
rec = hit(world, r, 1f-4, Inf32)
482497
if !ismissing(rec)
483498
# For debugging, represent vectors as RGB:
484499
# return color_vec3_in_rgb(rec.p) # show the normalized hit point
@@ -514,25 +529,25 @@ end
514529
function render(scene::HittableList, cam::Camera, image_width=400,
515530
n_samples=1)
516531
# Image
517-
aspect_ratio = 16.0/9.0 # TODO: use cam.aspect_ratio for consistency
532+
aspect_ratio = 16.0f0/9.0f0 # TODO: use cam.aspect_ratio for consistency
518533
image_height = convert(Int64, floor(image_width / aspect_ratio))
519534

520535
# Render
521-
img = zeros(RGB, image_height, image_width)
536+
img = zeros(RGB{MyFloat}, image_height, image_width)
522537
# Compared to C++, Julia is:
523538
# 1. column-major, i.e. iterate 1 column at a time, so invert i,j compared to C++
524539
# 2. 1-based, so no need to subtract 1 from image_width, etc.
525540
# 3. The array is Y-down, but `v` is Y-up
526541
for i in 1:image_height, j in 1:image_width
527-
accum_color = Vec3(0,0,0)
542+
accum_color = Vec3(0f0,0f0,0f0)
528543
for s in 1:n_samples
529-
u = j/image_width
530-
v = (image_height-i)/image_height # i is Y-down, v is Y-up!
544+
u = MyFloat(j/image_width)
545+
v = MyFloat((image_height-i)/image_height) # i is Y-down, v is Y-up!
531546
if s != 1 # 1st sample is always centered, for 1-sample/pixel
532547
# claforte: I think the C++ version had a bug, the rand offset was
533548
# between [0,1] instead of centered at 0, e.g. [-0.5, 0.5].
534-
u += (rand()-0.5) / image_width
535-
v += (rand()-0.5) / image_height
549+
u += MyFloat((rand()-0.5f0) / image_width)
550+
v += MyFloat((rand()-0.5f0) / image_height)
536551
end
537552
ray = get_ray(cam, u, v)
538553
accum_color += ray_color(ray, scene)
@@ -559,26 +574,26 @@ render(scene_4_spheres(), default_camera(), 96, 16)
559574
# Args:
560575
# refraction_ratio: incident refraction index divided by refraction index of
561576
# hit surface. i.e. η/η′ in the figure above"""
562-
function refract(dir::Vec3, n⃗::Vec3, refraction_ratio::Float64)
577+
function refract(dir::Vec3, n⃗::Vec3, refraction_ratio::MyFloat)
563578
cosθ = min(-dir n⃗, 1)
564579
r_out_perp = refraction_ratio * (dir + cosθ*n⃗)
565580
r_out_parallel = -√(abs(1-squared_length(r_out_perp))) * n⃗
566581
normalize(r_out_perp + r_out_parallel)
567582
end
568583

569-
@assert refract(Vec3(0.6,-0.8,0), Vec3(0,1,0), 1.0) == Vec3(0.6,-0.8,0) # unchanged
584+
@assert refract(Vec3(0.6,-0.8,0), Vec3(0,1,0), 1.0f0) == Vec3(0.6,-0.8,0) # unchanged
570585

571-
t_refract_widerθ = refract(Vec3(0.6,-0.8,0), Vec3(0,1,0), 2.0) # wider angle
586+
t_refract_widerθ = refract(Vec3(0.6,-0.8,0), Vec3(0,1,0), 2.0f0) # wider angle
572587
@assert isapprox(t_refract_widerθ, Vec3(0.87519, -0.483779, 0.0); atol=1e-3)
573588

574-
t_refract_narrowerθ = refract(Vec3(0.6,-0.8,0), Vec3(0,1,0), 0.5) # narrower angle
589+
t_refract_narrowerθ = refract(Vec3(0.6,-0.8,0), Vec3(0,1,0), 0.5f0) # narrower angle
575590
@assert isapprox(t_refract_narrowerθ, Vec3(0.3, -0.953939, 0.0); atol=1e-3)
576591

577592
mutable struct Dielectric <: Material
578-
ir::Float64 # index of refraction, i.e. η.
593+
ir::MyFloat # index of refraction, i.e. η.
579594
end
580595

581-
function reflectance(cosθ::Float64, refraction_ratio::Float64)
596+
function reflectance(cosθ::MyFloat, refraction_ratio::MyFloat)
582597
# Use Schlick's approximation for reflectance.
583598
# claforte: may be buggy? I'm getting black pixels in the Hollow Glass Sphere...
584599
r0 = (1-refraction_ratio) / (1+refraction_ratio)
@@ -588,9 +603,9 @@ end
588603

589604
function scatter(mat::Dielectric, r_in::Ray, rec::HitRecord)
590605
attenuation = Color(1,1,1)
591-
refraction_ratio = rec.front_face ? (1.0/mat.ir) : mat.ir # i.e. ηᵢ/ηₜ
592-
cosθ = min(-r_in.dirrec.n⃗, 1.0)
593-
sinθ = (1.0 - cosθ^2)
606+
refraction_ratio = rec.front_face ? (1.0f0/mat.ir) : mat.ir # i.e. ηᵢ/ηₜ
607+
cosθ = min(-r_in.dirrec.n⃗, 1.0f0)
608+
sinθ = (1.0f0 - cosθ^2)
594609
cannot_refract = refraction_ratio * sinθ > 1.0
595610
if cannot_refract || reflectance(cosθ, refraction_ratio) > rand()
596611
dir = reflect(r_in.dir, rec.n⃗)
@@ -705,5 +720,7 @@ t_cam = default_camera(t_lookfrom, t_lookat, Vec3(0.0,1.0,0.0), 20.0, 16.0/9.0,
705720
# 1.001 s ( 17406437 allocations: 425.87 MiB)
706721
# after forcing Ray and point() to use Float64 instead of AbstractFloat:
707722
# 397.905 ms (6269207 allocations: 201.30 MiB)
723+
# after forcing use of Float32 instead of Float64:
724+
# 487.680 ms (7128113 allocations: 196.89 MiB) # More allocations... something is causing them...
708725
@btime render(scene_diel_spheres(), t_cam, 96, 16)
709726

0 commit comments

Comments
 (0)