Skip to content

Commit

Permalink
[Dy2Stat]Modify dy2stat error message in runtime and format error mes…
Browse files Browse the repository at this point in the history
…sage (#35365)
  • Loading branch information
0x45f authored Sep 3, 2021
1 parent ef7bc36 commit a6cc567
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 10 deletions.
66 changes: 57 additions & 9 deletions python/paddle/fluid/dygraph/dygraph_to_static/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import sys
import traceback
import linecache
import re

from paddle.fluid.dygraph.dygraph_to_static.origin_info import Location, OriginInfo, global_origin_info_map

Expand Down Expand Up @@ -106,22 +107,34 @@ def __init__(self, location, function_name):
begin_lineno = max(1, self.location.lineno - int(SOURCE_CODE_RANGE / 2))

for i in range(begin_lineno, begin_lineno + SOURCE_CODE_RANGE):
line = linecache.getline(self.location.filepath, i)
line_lstrip = line.strip()
line = linecache.getline(self.location.filepath, i).rstrip('\n')
line_lstrip = line.lstrip()
self.source_code.append(line_lstrip)
blank_count.append(len(line) - len(line_lstrip))
if not line_lstrip: # empty line from source code
blank_count.append(-1)
else:
blank_count.append(len(line) - len(line_lstrip))

if i == self.location.lineno:
hint_msg = '~' * len(self.source_code[-1]) + ' <--- HERE'
self.source_code.append(hint_msg)
blank_count.append(blank_count[-1])
linecache.clearcache()

min_black_count = min(blank_count)
# remove top and bottom empty line in source code
while len(self.source_code) > 0 and not self.source_code[0]:
self.source_code.pop(0)
blank_count.pop(0)
while len(self.source_code) > 0 and not self.source_code[-1]:
self.source_code.pop(-1)
blank_count.pop(-1)

min_black_count = min([i for i in blank_count if i >= 0])
for i in range(len(self.source_code)):
self.source_code[i] = ' ' * (blank_count[i] - min_black_count +
BLANK_COUNT_BEFORE_FILE_STR * 2
) + self.source_code[i]
# if source_code[i] is empty line between two code line, dont add blank
if self.source_code[i]:
self.source_code[i] = ' ' * (blank_count[i] - min_black_count +
BLANK_COUNT_BEFORE_FILE_STR * 2
) + self.source_code[i]

def formated_message(self):
msg = ' ' * BLANK_COUNT_BEFORE_FILE_STR + 'File "{}", line {}, in {}\n'.format(
Expand Down Expand Up @@ -212,16 +225,51 @@ def _simplify_error_value(self):
1. Need a more robust way because the code of start_trace may change.
2. Set the switch to determine whether to simplify error_value
"""

assert self.in_runtime is True

error_value_lines = str(self.error_value).split("\n")
error_value_lines_strip = [mes.lstrip(" ") for mes in error_value_lines]

start_trace = "outputs = static_func(*inputs)"
start_idx = error_value_lines_strip.index(start_trace)

error_value_lines = error_value_lines[start_idx + 1:]
error_value_lines_strip = error_value_lines_strip[start_idx + 1:]

# use empty line to locate the bottom_error_message
empty_line_idx = error_value_lines_strip.index('')
bottom_error_message = error_value_lines[empty_line_idx + 1:]

filepath = ''
error_from_user_code = []
pattern = 'File "(?P<filepath>.+)", line (?P<lineno>.+), in (?P<function_name>.+)'
for i in range(0, len(error_value_lines_strip), 2):
if error_value_lines_strip[i].startswith("File "):
re_result = re.search(pattern, error_value_lines_strip[i])
tmp_filepath, lineno_str, function_name = re_result.groups()
code = error_value_lines_strip[i + 1] if i + 1 < len(
error_value_lines_strip) else ''
if i == 0:
filepath = tmp_filepath
if tmp_filepath == filepath:
error_from_user_code.append(
(tmp_filepath, int(lineno_str), function_name, code))

error_frame = []
whether_source_range = True
for filepath, lineno, funcname, code in error_from_user_code[::-1]:
loc = Location(filepath, lineno)
if whether_source_range:
traceback_frame = TraceBackFrameRange(loc, funcname)
whether_source_range = False
else:
traceback_frame = TraceBackFrame(loc, funcname, code)

error_frame.insert(0, traceback_frame.formated_message())

error_value_str = '\n'.join(error_value_lines)
error_frame.extend(bottom_error_message)
error_value_str = '\n'.join(error_frame)
self.error_value = self.error_type(error_value_str)

def raise_new_exception(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,16 @@ def test_func(self):
return


@paddle.jit.to_static
def func_error_in_runtime_with_empty_line(x):
x = fluid.dygraph.to_variable(x)
two = fluid.layers.fill_constant(shape=[1], value=2, dtype="int32")

x = fluid.layers.reshape(x, shape=[1, two])

return x


class TestFlags(unittest.TestCase):
def setUp(self):
self.reset_flags_to_default()
Expand Down Expand Up @@ -293,7 +303,26 @@ def set_message(self):
self.expected_message = \
[
'File "{}", line 54, in func_error_in_runtime'.format(self.filepath),
'x = fluid.layers.reshape(x, shape=[1, two])'
'x = fluid.dygraph.to_variable(x)',
'two = fluid.layers.fill_constant(shape=[1], value=2, dtype="int32")',
'x = fluid.layers.reshape(x, shape=[1, two])',
'<--- HERE',
'return x'
]


class TestErrorStaticLayerCallInRuntime2(TestErrorStaticLayerCallInRuntime):
def set_func(self):
self.func = func_error_in_runtime_with_empty_line

def set_message(self):
self.expected_message = \
[
'File "{}", line 106, in func_error_in_runtime_with_empty_line'.format(self.filepath),
'two = fluid.layers.fill_constant(shape=[1], value=2, dtype="int32")',
'x = fluid.layers.reshape(x, shape=[1, two])',
'<--- HERE',
'return x'
]


Expand Down

0 comments on commit a6cc567

Please sign in to comment.