receding_horizon.jl
# Assumed dependencies (likely loaded in the package's top-level module file):
# POMDPs.jl for the MDP/Policy interface, POMDPTools (POMDPModelTools on older
# versions) for weighted_iterator, and ArgCheck for @argcheck.
using POMDPs
import POMDPs: action
using POMDPTools: weighted_iterator
using ArgCheck: @argcheck

# Cache to speed up policy evaluation (it needs to go over all states):
# maps (state, horizon) => (value, best action).
const Cache{S,A} = Dict{Tuple{S,Int},Tuple{Float64,A}}
struct RecedingHorizonPolicy{S,A} <: Policy
    mdp::MDP{S,A}
    horizon::Int
    cache::Union{Cache{S,A}, Nothing}
end
# NOTE: shared_cache is not thread-safe!
function RecedingHorizonPolicy(mdp::MDP{S,A}, horizon::Int; shared_cache = false) where {S,A}
    cache = shared_cache ? Cache{S,A}() : nothing
    RecedingHorizonPolicy(mdp, horizon, cache)
end
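# Sketch (assumption, not part of the original file): because the shared cache is a
# plain Dict, concurrent calls to `action` must not go through one shared policy.
# One option is a separate policy (and cache) per task, e.g.:
#   policies = [RecedingHorizonPolicy(mdp, horizon; shared_cache = true)
#               for _ in 1:Threads.nthreads()]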
# POMDPs.jl policy interface: pick the best action for `state` with a
# `horizon`-step lookahead, reusing the shared cache when one was requested.
function action(policy::RecedingHorizonPolicy, state)
    _, best_action = if !isnothing(policy.cache)
        receding_horizon(policy.mdp, state, policy.horizon, policy.cache)
    else
        receding_horizon(policy.mdp, state, policy.horizon)
    end
    best_action
end
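# Minimal usage sketch (assumed example, not part of the original file): uses
# POMDPModels' SimpleGridWorld as a stand-in MDP; any finite-action POMDPs.jl MDP
# should work the same way.
#   using POMDPModels
#   gw = SimpleGridWorld()
#   policy = RecedingHorizonPolicy(gw, 3; shared_cache = true)
#   a = action(policy, GWPos(1, 1))   # best action found by a 3-step lookahead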
# From https://hal.laas.fr/hal-02413636/document -- Algorithm 1
# TODO: Clean up type annotations, figure out why type inference fails
# TODO: Non-recursive version (DFS)
function receding_horizon(
    mdp::MDP{S,A},
    state::S,
    horizon::Int,
    cache = Cache{S,A}(),
) where {S,A}
    @argcheck horizon >= 0
    # Base case: no lookahead left, so the value is 0 and any action will do.
    if horizon == 0
        return 0.0, first(actions(mdp))::A
    end
    # Return the memoized (value, action) pair if this subproblem was already solved.
    key = (state, horizon)
    if haskey(cache, key)
        return cache[key]
    end
    # Running (max value, argmax action), initialized with an arbitrary action.
    best::Tuple{Float64,A} = (-Inf, first(actions(mdp)))
    discount_factor = discount(mdp)
    for action::A in actions(mdp)
        dist = transition(mdp, state, action)
        # Expected utility of `action`: immediate reward plus the discounted
        # (horizon - 1)-step value of each successor state.
        util::Float64 = 0.0
        for (statep::S, proba::Float64) in weighted_iterator(dist)
            r = reward(mdp, state, action, statep)
            v, _ = receding_horizon(mdp, statep, horizon - 1, cache)
            util += proba * (r + discount_factor * v)
        end
        # Lexicographic tuple max keeps the pair with the higher utility.
        best = max(best, (util, action))
    end
    cache[key] = best
    best
end
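# Direct-call sketch (assumed example, continuing the SimpleGridWorld sketch above):
# the cache can be kept across calls so lookahead results computed from one root
# state may be reused from another.
#   cache = Cache{GWPos,Symbol}()
#   v1, a1 = receding_horizon(gw, GWPos(1, 1), 3, cache)
#   v2, a2 = receding_horizon(gw, GWPos(2, 1), 3, cache)  # may hit entries cached above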