-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathblack_isort_linter.py
172 lines (157 loc) · 5.2 KB
/
black_isort_linter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
"""Format file with black and isort."""
# PyTorch LICENSE. See LICENSE file in the root directory of this source tree.
from __future__ import annotations
import argparse
import concurrent.futures
import logging
import os
import subprocess
import sys
import lintrunner_adapters
from lintrunner_adapters import LintMessage, LintSeverity, as_posix, run_command
# Linter code attached to every LintMessage emitted by this adapter and shown
# in the command-line help text.
LINTER_CODE = "BLACK-ISORT"
def check_file(
    filename: str,
    retries: int,
    timeout: int,
    *,
    fast: bool = False,
) -> list[LintMessage]:
    """Check whether ``filename`` is already formatted by isort and black.

    The file's bytes are piped through ``isort`` and isort's output is then
    piped through ``black``. This function only opens the file for reading.

    Args:
        filename: Path of the file to check. A ``.pyi`` or ``.ipynb`` suffix
            adds the matching ``--pyi`` / ``--ipynb`` flag to black.
        retries: Forwarded to ``run_command`` for both tool invocations.
        timeout: Forwarded to ``run_command`` for both tool invocations.
        fast: When true, black is invoked with ``--fast``.

    Returns:
        An empty list when black's output equals the bytes read from the
        file. Otherwise a one-element list holding a ``format`` warning
        (with the original and reformatted text), a ``timeout`` error, or a
        ``command-failed`` advice message.
    """

    def report(severity, name, description, before=None, after=None):
        # Every message shares the path/line/char/code fields.
        return LintMessage(
            path=filename,
            line=None,
            char=None,
            code=LINTER_CODE,
            severity=severity,
            name=name,
            original=before,
            replacement=after,
            description=description,
        )

    try:
        with open(filename, "rb") as handle:
            source_bytes = handle.read()

        # Run isort first then black so we get consistent result
        # even if isort is not using the black profile
        with open(filename, "rb") as handle:
            isort_proc = run_command(
                [sys.executable, "-misort", "-"],
                stdin=handle,
                retries=retries,
                timeout=timeout,
                check=True,
            )
        sorted_bytes = isort_proc.stdout

        # Optional black flags, in the order they appear on the command line.
        black_flags = []
        if filename.endswith(".pyi"):
            black_flags.append("--pyi")
        if filename.endswith(".ipynb"):
            black_flags.append("--ipynb")
        if fast:
            black_flags.append("--fast")

        # Pipe isort's result to black
        black_proc = run_command(
            [
                sys.executable,
                "-mblack",
                *black_flags,
                "--stdin-filename",
                filename,
                "-",
            ],
            stdin=None,
            input=sorted_bytes,
            retries=retries,
            timeout=timeout,
            check=True,
        )
    except subprocess.TimeoutExpired:
        return [
            report(
                LintSeverity.ERROR,
                "timeout",
                (
                    "black-isort timed out while trying to process a file. "
                    "Please report an issue in pytorch/pytorch with the "
                    "label 'module: lint'"
                ),
            )
        ]
    except (OSError, subprocess.CalledProcessError) as err:
        if isinstance(err, subprocess.CalledProcessError):
            # A tool exited non-zero: show the command line and its output.
            details = (
                "COMMAND (exit code {returncode})\n"
                "{command}\n\n"
                "STDERR\n{stderr}\n\n"
                "STDOUT\n{stdout}"
            ).format(
                returncode=err.returncode,
                command=" ".join(as_posix(x) for x in err.cmd),
                stderr=err.stderr.decode("utf-8").strip() or "(empty)",
                stdout=err.stdout.decode("utf-8").strip() or "(empty)",
            )
        else:
            details = f"Failed due to {err.__class__.__name__}:\n{err}"
        return [report(LintSeverity.ADVICE, "command-failed", details)]

    formatted_bytes = black_proc.stdout
    if source_bytes != formatted_bytes:
        return [
            report(
                LintSeverity.WARNING,
                "format",
                "Run `lintrunner -a` to apply this patch.",
                before=source_bytes.decode("utf-8"),
                after=formatted_bytes.decode("utf-8"),
            )
        ]
    return []
def main() -> None:
    """Parse command-line options and check every given file concurrently.

    Options defined here are ``--fast`` and ``--timeout``. The remaining
    attributes read from the parsed arguments (``verbose``, ``retries``,
    ``filenames``) are presumably registered by
    ``lintrunner_adapters.add_default_options`` -- confirm against that
    package. Arguments may also be read from a file via the ``@`` prefix.

    Each file is submitted to a thread pool running ``check_file``; every
    returned lint message is emitted through its ``display()`` method as
    results complete.

    Raises:
        Exception: Any exception raised while checking a file is logged at
            CRITICAL level with the file name and then re-raised.
    """
    parser = argparse.ArgumentParser(
        description=f"Format files with black-isort. Linter code: {LINTER_CODE}",
        fromfile_prefix_chars="@",
    )
    parser.add_argument(
        "--fast",
        action="store_true",
        help="If --fast given, skip temporary sanity checks.",
    )
    parser.add_argument(
        "--timeout",
        default=90,
        type=int,
        help="seconds to wait for black-isort",
    )
    lintrunner_adapters.add_default_options(parser)
    args = parser.parse_args()

    # Verbose mode logs everything; otherwise small batches log at DEBUG and
    # large batches (1000 files or more) only at INFO.
    if args.verbose:
        log_level = logging.NOTSET
    elif len(args.filenames) < 1000:
        log_level = logging.DEBUG
    else:
        log_level = logging.INFO

    logging.basicConfig(
        format="<%(threadName)s:%(levelname)s> %(message)s",
        level=log_level,
        stream=sys.stderr,
    )

    with concurrent.futures.ThreadPoolExecutor(
        max_workers=os.cpu_count(),
        thread_name_prefix="Thread",
    ) as executor:
        # Map each future back to its file name for error reporting.
        # ``fast`` must be forwarded explicitly; otherwise the --fast flag
        # is parsed but never reaches black.
        futures = {
            executor.submit(
                check_file,
                x,
                args.retries,
                args.timeout,
                fast=args.fast,
            ): x
            for x in args.filenames
        }
        for future in concurrent.futures.as_completed(futures):
            try:
                for lint_message in future.result():
                    lint_message.display()
            except Exception:
                logging.critical('Failed at "%s".', futures[future])
                raise
# Run the linter only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()