
Merge branch 'MegaIng-bytes-support'
erezsh committed Jul 31, 2020
2 parents b2d1761 + 7c6e94b commit 9ee8428
Showing 9 changed files with 168 additions and 29 deletions.
1 change: 1 addition & 0 deletions docs/classes.md
@@ -128,6 +128,7 @@ Useful for caching and multiprocessing.
 - **priority** - How priorities should be evaluated - auto, none, normal, invert (Default: auto)
 - **lexer_callbacks** - Dictionary of callbacks for the lexer. May alter tokens during lexing. Use with caution.
 - **edit_terminals** - A callback
+- **use_bytes** - Accept and parse an input of type `bytes` instead of `str`. Grammar should still be specified as `str`, and terminal values are assumed to be `latin-1`.


 #### Using Unicode character classes with `regex`
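
As a usage sketch for the `use_bytes` option documented above (the grammar below is illustrative, not part of this commit): the grammar stays a `str`, while the input handed to `parse()` may now be `bytes`.

```python
from lark import Lark

# Minimal sketch of use_bytes with a hypothetical grammar.
# The grammar itself is still a str; only the parsed input is bytes.
parser = Lark(r"""
    start: WORD+
    WORD: /[a-z]+/
    %ignore " "
""", parser='lalr', use_bytes=True)

tree = parser.parse(b"hello world")  # bytes in; terminals matched as latin-1
```
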
5 changes: 4 additions & 1 deletion lark-stubs/lark.pyi
@@ -31,10 +31,12 @@ class LarkOptions:
     lexer_callbacks: Dict[str, Callable[[Token], Token]]
     cache: Union[bool, str]
     g_regex_flags: int
+    use_bytes: bool


 class Lark:
     source: str
+    grammar_source: str
     options: LarkOptions
     lexer: Lexer
     terminals: List[TerminalDef]
@@ -56,7 +58,8 @@ class Lark:
         maybe_placeholders: bool = False,
         lexer_callbacks: Optional[Dict[str, Callable[[Token], Token]]] = None,
         cache: Union[bool, str] = False,
-        g_regex_flags: int = ...
+        g_regex_flags: int = ...,
+        use_bytes: bool = False,
     ):
         ...

5 changes: 3 additions & 2 deletions lark/common.py
@@ -4,17 +4,18 @@
 ###{standalone

 class LexerConf(Serialize):
-    __serialize_fields__ = 'tokens', 'ignore', 'g_regex_flags'
+    __serialize_fields__ = 'tokens', 'ignore', 'g_regex_flags', 'use_bytes'
     __serialize_namespace__ = TerminalDef,

-    def __init__(self, tokens, re_module, ignore=(), postlex=None, callbacks=None, g_regex_flags=0, skip_validation=False):
+    def __init__(self, tokens, re_module, ignore=(), postlex=None, callbacks=None, g_regex_flags=0, skip_validation=False, use_bytes=False):
         self.tokens = tokens # TODO should be terminals
         self.ignore = ignore
         self.postlex = postlex
         self.callbacks = callbacks or {}
         self.g_regex_flags = g_regex_flags
         self.re_module = re_module
         self.skip_validation = skip_validation
+        self.use_bytes = use_bytes

     def _deserialize(self):
         self.callbacks = {} # TODO
17 changes: 13 additions & 4 deletions lark/exceptions.py
@@ -28,9 +28,14 @@ def get_context(self, text, span=40):
         pos = self.pos_in_stream
         start = max(pos - span, 0)
         end = pos + span
-        before = text[start:pos].rsplit('\n', 1)[-1]
-        after = text[pos:end].split('\n', 1)[0]
-        return before + after + '\n' + ' ' * len(before) + '^\n'
+        if not isinstance(text, bytes):
+            before = text[start:pos].rsplit('\n', 1)[-1]
+            after = text[pos:end].split('\n', 1)[0]
+            return before + after + '\n' + ' ' * len(before) + '^\n'
+        else:
+            before = text[start:pos].rsplit(b'\n', 1)[-1]
+            after = text[pos:end].split(b'\n', 1)[0]
+            return (before + after + b'\n' + b' ' * len(before) + b'^\n').decode("ascii", "backslashreplace")

     def match_examples(self, parse_fn, examples, token_type_match_fallback=False):
         """ Given a parser instance and a dictionary mapping some label with
@@ -67,7 +72,11 @@ def match_examples(self, parse_fn, examples, token_type_match_fallback=False):

 class UnexpectedCharacters(LexError, UnexpectedInput):
     def __init__(self, seq, lex_pos, line, column, allowed=None, considered_tokens=None, state=None, token_history=None):
-        message = "No terminal defined for '%s' at line %d col %d" % (seq[lex_pos], line, column)
+
+        if isinstance(seq, bytes):
+            message = "No terminal defined for '%s' at line %d col %d" % (seq[lex_pos:lex_pos+1].decode("ascii", "backslashreplace"), line, column)
+        else:
+            message = "No terminal defined for '%s' at line %d col %d" % (seq[lex_pos], line, column)

         self.line = line
         self.column = column
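
The bytes branch above builds the caret diagram in `bytes` and only decodes at the very end with the `backslashreplace` error handler, so undecodable bytes stay visible in the error message instead of raising. A quick sketch of that decode step:

```python
# Sketch of the decode used in get_context: non-ascii bytes are rendered
# as escape sequences rather than raising UnicodeDecodeError.
assert b"caf\xe9\n^\n".decode("ascii", "backslashreplace") == "caf\\xe9\n^\n"
```
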
19 changes: 14 additions & 5 deletions lark/lark.py
Expand Up @@ -4,7 +4,7 @@
from io import open


from .utils import STRING_TYPE, Serialize, SerializeMemoizer, FS
from .utils import STRING_TYPE, Serialize, SerializeMemoizer, FS, isascii
from .load_grammar import load_grammar
from .tree import Tree
from .common import LexerConf, ParserConf
Expand Down Expand Up @@ -82,6 +82,7 @@ class LarkOptions(Serialize):
invert (Default: auto)
lexer_callbacks - Dictionary of callbacks for the lexer. May alter
tokens during lexing. Use with caution.
use_bytes - Accept an input of type `bytes` instead of `str` (Python 3 only).
edit_terminals - A callback
"""
if __doc__:
Expand All @@ -105,6 +106,7 @@ class LarkOptions(Serialize):
'maybe_placeholders': False,
'edit_terminals': None,
'g_regex_flags': 0,
'use_bytes': False,
}

def __init__(self, options_dict):
Expand All @@ -114,7 +116,7 @@ def __init__(self, options_dict):
for name, default in self._defaults.items():
if name in o:
value = o.pop(name)
if isinstance(default, bool) and name != 'cache':
if isinstance(default, bool) and name not in ('cache', 'use_bytes'):
value = bool(value)
else:
value = default
Expand Down Expand Up @@ -187,6 +189,13 @@ def __init__(self, grammar, **options):
grammar = read()

assert isinstance(grammar, STRING_TYPE)
self.grammar_source = grammar
if self.options.use_bytes:
if not isascii(grammar):
raise ValueError("Grammar must be ascii only, when use_bytes=True")
if sys.version_info[0] == 2 and self.options.use_bytes != 'force':
raise NotImplementedError("`use_bytes=True` may have issues on python2."
"Use `use_bytes='force'` to use it at your own risk.")

cache_fn = None
if self.options.cache:
Expand All @@ -196,7 +205,7 @@ def __init__(self, grammar, **options):
cache_fn = self.options.cache
else:
if self.options.cache is not True:
raise ValueError("cache must be bool or str")
raise ValueError("cache argument must be bool or str")
unhashable = ('transformer', 'postlex', 'lexer_callbacks', 'edit_terminals')
from . import __version__
options_str = ''.join(k+str(v) for k, v in options.items() if k not in unhashable)
Expand Down Expand Up @@ -252,7 +261,7 @@ def __init__(self, grammar, **options):
for t in self.terminals:
self.options.edit_terminals(t)

self._terminals_dict = {t.name:t for t in self.terminals}
self._terminals_dict = {t.name: t for t in self.terminals}

# If the user asked to invert the priorities, negate them all here.
# This replaces the old 'resolve__antiscore_sum' option.
Expand All @@ -276,7 +285,7 @@ def __init__(self, grammar, **options):
if hasattr(t, term.name):
lexer_callbacks[term.name] = getattr(t, term.name)

self.lexer_conf = LexerConf(self.terminals, re_module, self.ignore_tokens, self.options.postlex, lexer_callbacks, self.options.g_regex_flags)
self.lexer_conf = LexerConf(self.terminals, re_module, self.ignore_tokens, self.options.postlex, lexer_callbacks, self.options.g_regex_flags, use_bytes=self.options.use_bytes)

if self.options.parser:
self.parser = self._build_parser()
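
A short sketch (hypothetical snippet, not from this commit) of the new validation path added above: a non-ascii grammar is rejected up front when `use_bytes=True`, before any grammar loading happens.

```python
from lark import Lark

# Hypothetical demonstration of the ascii guard.
try:
    Lark('start: "caf\xe9"', use_bytes=True)
except ValueError as e:
    print(e)  # -> Grammar must be ascii only, when use_bytes=True
```
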
31 changes: 18 additions & 13 deletions lark/lexer.py
@@ -139,8 +139,8 @@ def __eq__(self, other):


 class LineCounter:
-    def __init__(self):
-        self.newline_char = '\n'
+    def __init__(self, newline_char):
+        self.newline_char = newline_char
         self.char_pos = 0
         self.line = 1
         self.column = 1
@@ -169,7 +169,7 @@ def __init__(self, lexer, state=None):
     def lex(self, stream, newline_types, ignore_types):
         newline_types = frozenset(newline_types)
         ignore_types = frozenset(ignore_types)
-        line_ctr = LineCounter()
+        line_ctr = LineCounter('\n' if not self.lexer.use_bytes else b'\n')
         last_token = None

         while line_ctr.char_pos < len(stream):
@@ -230,7 +230,7 @@ def __call__(self, t):



-def _create_unless(terminals, g_regex_flags, re_):
+def _create_unless(terminals, g_regex_flags, re_, use_bytes):
     tokens_by_type = classify(terminals, lambda t: type(t.pattern))
     assert len(tokens_by_type) <= 2, tokens_by_type.keys()
     embedded_strs = set()
@@ -247,31 +247,34 @@ def _create_unless(terminals, g_regex_flags, re_):
                 if strtok.pattern.flags <= retok.pattern.flags:
                     embedded_strs.add(strtok)
         if unless:
-            callback[retok.name] = UnlessCallback(build_mres(unless, g_regex_flags, re_, match_whole=True))
+            callback[retok.name] = UnlessCallback(build_mres(unless, g_regex_flags, re_, match_whole=True, use_bytes=use_bytes))

     terminals = [t for t in terminals if t not in embedded_strs]
     return terminals, callback


-def _build_mres(terminals, max_size, g_regex_flags, match_whole, re_):
+def _build_mres(terminals, max_size, g_regex_flags, match_whole, re_, use_bytes):
     # Python sets an unreasonable group limit (currently 100) in its re module
     # Worse, the only way to know we reached it is by catching an AssertionError!
     # This function recursively tries less and less groups until it's successful.
     postfix = '$' if match_whole else ''
     mres = []
     while terminals:
+        pattern = u'|'.join(u'(?P<%s>%s)' % (t.name, t.pattern.to_regexp() + postfix) for t in terminals[:max_size])
+        if use_bytes:
+            pattern = pattern.encode('latin-1')
         try:
-            mre = re_.compile(u'|'.join(u'(?P<%s>%s)'%(t.name, t.pattern.to_regexp()+postfix) for t in terminals[:max_size]), g_regex_flags)
+            mre = re_.compile(pattern, g_regex_flags)
         except AssertionError:  # Yes, this is what Python provides us.. :/
-            return _build_mres(terminals, max_size//2, g_regex_flags, match_whole, re_)
+            return _build_mres(terminals, max_size//2, g_regex_flags, match_whole, re_, use_bytes)

         # terms_from_name = {t.name: t for t in terminals[:max_size]}
         mres.append((mre, {i:n for n,i in mre.groupindex.items()} ))
         terminals = terminals[max_size:]
     return mres

-def build_mres(terminals, g_regex_flags, re_, match_whole=False):
-    return _build_mres(terminals, len(terminals), g_regex_flags, match_whole, re_)
+def build_mres(terminals, g_regex_flags, re_, use_bytes, match_whole=False):
+    return _build_mres(terminals, len(terminals), g_regex_flags, match_whole, re_, use_bytes)

 def _regexp_has_newline(r):
     r"""Expressions that may indicate newlines in a regexp:
@@ -321,12 +324,13 @@ def __init__(self, conf):
         self.terminals = terminals
         self.user_callbacks = conf.callbacks
         self.g_regex_flags = conf.g_regex_flags
+        self.use_bytes = conf.use_bytes

         self._mres = None
         # self.build(g_regex_flags)

     def _build(self):
-        terminals, self.callback = _create_unless(self.terminals, self.g_regex_flags, re_=self.re)
+        terminals, self.callback = _create_unless(self.terminals, self.g_regex_flags, re_=self.re, use_bytes=self.use_bytes)
         assert all(self.callback.values())

         for type_, f in self.user_callbacks.items():
@@ -336,7 +340,7 @@ def _build(self):
             else:
                 self.callback[type_] = f

-        self._mres = build_mres(terminals, self.g_regex_flags, self.re)
+        self._mres = build_mres(terminals, self.g_regex_flags, self.re, self.use_bytes)

     @property
     def mres(self):
@@ -365,7 +369,8 @@ def __init__(self, conf, states, always_accept=()):
             assert t.name not in tokens_by_name, t
             tokens_by_name[t.name] = t

-        trad_conf = type(conf)(terminals, conf.re_module, conf.ignore, callbacks=conf.callbacks, g_regex_flags=conf.g_regex_flags, skip_validation=conf.skip_validation)
+        trad_conf = copy(conf)
+        trad_conf.tokens = terminals

         lexer_by_tokens = {}
         self.lexers = {}
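
The key mechanic in `_build_mres` above: the combined named-group pattern is still built as `str`, and when `use_bytes` is set it is encoded to latin-1, so the compiled regex matches `bytes` input. A standalone sketch with hypothetical terminals:

```python
import re

# Sketch of the _build_mres idea with two made-up terminals.
terminals = [('WORD', '[a-z]+'), ('NUMBER', '[0-9]+')]
pattern = u'|'.join(u'(?P<%s>%s)' % t for t in terminals)
mre = re.compile(pattern.encode('latin-1'))  # bytes pattern -> bytes-matching regex

m = mre.match(b'hello42')
assert m.lastgroup == 'WORD' and m.group(0) == b'hello'
```
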
2 changes: 2 additions & 0 deletions lark/parser_frontends.py
@@ -189,6 +189,8 @@ def _prepare_match(self, lexer_conf):
             else:
                 if width == 0:
                     raise ValueError("Dynamic Earley doesn't allow zero-width regexps", t)
+            if lexer_conf.use_bytes:
+                regexp = regexp.encode('utf-8')

             self.regexps[t.name] = lexer_conf.re_module.compile(regexp, lexer_conf.g_regex_flags)
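
The encode step exists because `re` refuses to apply a `str` pattern to a `bytes` subject; pattern and input must be the same type. A quick sketch:

```python
import re

# Why each terminal regexp must be encoded before compiling for bytes input:
try:
    re.compile(r"[a-z]+").match(b"abc")
except TypeError:
    pass  # "cannot use a string pattern on a bytes-like object"

assert re.compile(r"[a-z]+".encode('utf-8')).match(b"abc")
```
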
15 changes: 14 additions & 1 deletion lark/utils.py
@@ -305,4 +305,17 @@ def combine_alternatives(lists):

 class FS:
     open = open
-    exists = os.path.exists
\ No newline at end of file
+    exists = os.path.exists
+
+
+
+def isascii(s):
+    """ str.isascii only exists in python3.7+ """
+    try:
+        return s.isascii()
+    except AttributeError:
+        try:
+            s.encode('ascii')
+            return True
+        except (UnicodeDecodeError, UnicodeEncodeError):
+            return False
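
Assumed behavior of this backport, mirroring `str.isascii` on Python 3.7+ (the second `try` also tolerates Python 2 byte-strings, where `encode` can raise `UnicodeDecodeError`):

```python
# Expected semantics of the isascii helper defined above:
assert isascii("start: WORD")   # plain ascii -> True
assert not isascii("caf\xe9")   # non-ascii -> False
```
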