diff --git a/nle/agent/agent.py b/nle/agent/agent.py index 27576c159..02abc0d49 100644 --- a/nle/agent/agent.py +++ b/nle/agent/agent.py @@ -41,6 +41,7 @@ import nle # noqa: F401, E402 from nle.agent import vtrace # noqa: E402 +from nle import nethack # noqa: E402 # yapf: disable @@ -743,7 +744,7 @@ def __init__( self.crop = Crop(self.H, self.W, self.crop_dim, self.crop_dim) - self.embed = nn.Embedding(5991, self.k_dim) + self.embed = nn.Embedding(nethack.MAX_GLYPH, self.k_dim) K = embedding_dim # number of input filters F = 3 # filter dimensions