[better_errors] Port the Pallas debug info mechanisms to the new JAX DebugInfo.

Now that we carry debug information in the Jaxpr, we can remove the Pallas-specific
tracking of the `func_src_info`, e.g., `NameAndSrcInfo`.
gnecula committed Feb 25, 2025
1 parent 5b13883 commit c4e0db6
Showing 15 changed files with 121 additions and 196 deletions.
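To illustrate the change described in the commit message, here is a minimal sketch (not part of the diff) of reading the kernel name and source info from the jaxpr's `DebugInfo` instead of a separate `NameAndSrcInfo`. It assumes a JAX version that includes this commit, where every `Jaxpr` carries a `debug_info`; the printed values are only what one would expect, not verified output.

```python
# Illustrative sketch only: where Pallas lowering previously threaded a
# NameAndSrcInfo value, it can now read the same information off the jaxpr.
import jax
import jax.numpy as jnp

def my_kernel(x):
  return x * 2

closed = jax.make_jaxpr(my_kernel)(jnp.ones(4))
dbg = closed.jaxpr.debug_info     # DebugInfo attached to the Jaxpr
print(dbg.func_name)              # expected: "my_kernel"
print(dbg.func_src_info)          # expected: "my_kernel at <file>:<line>"
```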
7 changes: 6 additions & 1 deletion jax/_src/interpreters/mlir.py
@@ -931,6 +931,11 @@ def unflatten_ir_values_like_types(xs: Iterable[ir.Value],

_module_name_regex = re.compile(r"[^\w.-]")

def sanitize_name(name: str) -> str:
  """Ensure a name is usable as module or function name."""
  return _module_name_regex.sub("_", name)


def sharded_aval(aval: core.AbstractValue,
sharding: JSharding | AUTO | None) -> core.AbstractValue:
"""Returns the new aval sharded based on sharding proto."""
@@ -1211,7 +1216,7 @@ def lower_jaxpr_to_module(
# Remove module name characters that XLA would alter. This ensures that
# XLA computation preserves the module name.
attrs = ctx.module.operation.attributes
module_name = _module_name_regex.sub("_", module_name)
module_name = sanitize_name(module_name)
attrs["sym_name"] = ir.StringAttr.get(module_name)
attrs["mhlo.num_replicas"] = i32_attr(num_replicas)
attrs["mhlo.num_partitions"] = i32_attr(num_partitions)
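For reference, a hedged sketch of how the new `sanitize_name` helper behaves: per the regex above (`[^\w.-]`), any character that is not a word character, dot, or dash is replaced with an underscore. The import path below is the private module touched by this hunk and is shown for illustration only.

```python
# Sketch only: sanitize_name exists after this commit in the private module
# jax._src.interpreters.mlir; characters outside [\w.-] become "_".
from jax._src.interpreters import mlir

assert mlir.sanitize_name("my kernel/0") == "my_kernel_0"           # space and "/" rewritten
assert mlir.sanitize_name("pallas_call.foo") == "pallas_call.foo"   # already legal
```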
5 changes: 5 additions & 0 deletions jax/_src/linear_util.py
@@ -321,6 +321,11 @@ def resolve_result_paths(self) -> DebugInfo:
def func_name(self) -> str:
return self.func_src_info.split(" ")[0]

def replace_func_name(self, name: str) -> DebugInfo:
  func_src_comps = self.func_src_info.split(" ")
  func_src_comps[0] = name
  return self._replace(func_src_info=" ".join(func_src_comps))

def safe_arg_names(self, expected: int) -> tuple[str, ...]:
"""Get the arg_names with a safety check."""
if len(self.arg_names) == expected:
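A standalone sketch of the string manipulation `replace_func_name` performs, assuming the `func_src_info` format `"{function_name} at {file_name}:{line_number}"` referenced elsewhere in this commit. The helper below is a plain-Python stand-in for illustration, not the actual `DebugInfo` method.

```python
# Plain-Python stand-in mirroring DebugInfo.replace_func_name above: the
# function name is the first space-separated component of func_src_info.
def replace_func_name(func_src_info: str, name: str) -> str:
  comps = func_src_info.split(" ")
  comps[0] = name
  return " ".join(comps)

assert replace_func_name("my_kernel at kernel.py:10", "index_map") == (
    "index_map at kernel.py:10")
```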
61 changes: 9 additions & 52 deletions jax/_src/pallas/core.py
@@ -78,48 +78,6 @@ class CompilerParams(Protocol):
class Buffered:
buffer_count: int

# TODO(necula): clean up the splitting of the fun_sourceinfo
@dataclasses.dataclass(frozen=True)
class NameAndSrcInfo:
#: The name of the pallas_call or the name of the kernel function.
name: str
#: the source info, and the name of kernel function if not in `name`.`
src_info: str

def __str__(self):
return f"{self.name}{' ' if self.src_info else ''}{self.src_info}"
__repr__ = __str__

replace = dataclasses.replace


@staticmethod
def from_pallas_call(pallas_call_name: str | None,
src_info : str | None) -> NameAndSrcInfo:
"""Formats the name and the source info.
Args:
pallas_call_name: The `name` argument to pallas_call.
src_info: The result of `api_util.fun_source_info(kernel)`, in the form
"{function_name} at {file_name}:{line_number}".
"""
if pallas_call_name is not None:
pallas_call_name = mlir._module_name_regex.sub("_", pallas_call_name)
if src_info is None:
return NameAndSrcInfo(
"unknown" if pallas_call_name is None else pallas_call_name,
"")
if pallas_call_name is not None:
return NameAndSrcInfo(pallas_call_name,
f"for kernel function {src_info}")
src_info_parts = src_info.split(" at ")
if len(src_info_parts) > 1:
return NameAndSrcInfo(src_info_parts[0],
"at " + " ".join(src_info_parts[1:]))
else:
return NameAndSrcInfo(src_info_parts[0], "")


split_list = util.split_list

map, unsafe_map = util.safe_map, map
@@ -350,6 +308,8 @@ def __repr__(self):

IndexingMode = Union[Blocked, Unblocked]

def default_index_map(ndim: int) -> Callable:
  return lambda *args: (0,) * ndim

@dataclasses.dataclass
class BlockSpec:
@@ -376,7 +336,8 @@ def to_block_mapping(
mapped_dims: tuple[int, ...],
) -> BlockMapping:
if self.index_map is None:
index_map_func = lambda *args: (0,) * len(array_aval.shape)
index_map_func = default_index_map(len(array_aval.shape))
api_util.save_wrapped_fun_sourceinfo(index_map_func, default_index_map)
else:
index_map_func = self.index_map
if self.block_shape is None:
@@ -417,33 +378,31 @@ def to_block_mapping(
index_map_func, fake_index_map_args,
fake_index_map_kwargs)
flat_index_map_fun, index_map_out_tree_thunk = api_util.flatten_fun(
lu.wrap_init(index_map_func, debug_info=debug), index_map_tree)
index_map_src_info = NameAndSrcInfo.from_pallas_call(
None, debug and debug.func_src_info
)
lu.wrap_init(index_map_func, debug_info=debug), index_map_tree)
with tracing_grid_env(grid, mapped_dims):
jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(
flat_index_map_fun, index_map_avals
)

mapped_block_shape = tuple(mapped if s is None else s for s in block_shape)
if len(out_avals) != len(block_shape):
raise ValueError(
f"Index map function {index_map_src_info} for "
f"Index map function {debug.func_src_info} for "
f"{origin} must return "
f"{len(block_shape)} values to match {block_shape=}. "
f"Currently returning {len(out_avals)} values."
)
for i, ov in enumerate(out_avals):
if ov.shape or ov.dtype not in [jnp.int32, jnp.int64]:
raise ValueError(
f"Index map function {index_map_src_info} for "
f"Index map function {debug.func_src_info} for "
f"{origin} must return integer scalars. Output[{i}] has type "
f"{ov}."
)

if consts:
raise ValueError(
f"Index map function {index_map_src_info} for "
f"Index map function {debug.func_src_info} for "
f"{origin} must not capture constants: {consts}"
)

@@ -453,7 +412,6 @@ def to_block_mapping(
block_shape=mapped_block_shape,
transformed_block_aval=block_aval, # There are no transforms by default
index_map_jaxpr=jax_core.ClosedJaxpr(jaxpr, consts),
index_map_src_info=index_map_src_info,
indexing_mode=self.indexing_mode,
array_shape_dtype=jax.ShapeDtypeStruct(
array_aval_shape, array_aval.dtype
@@ -496,7 +454,6 @@ class BlockMapping:
block_shape: tuple[Mapped | int, ...]
transformed_block_aval: AbstractMemoryRef
index_map_jaxpr: jax_core.ClosedJaxpr
index_map_src_info: NameAndSrcInfo
indexing_mode: IndexingMode
array_shape_dtype: jax.ShapeDtypeStruct # The whole array
origin: OriginStr
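The `default_index_map` helper added to `pallas/core.py` above produces the index map used when a `BlockSpec` has no `index_map`; behaviorally it maps every grid position to block `(0, ..., 0)`. A standalone sketch of that behavior (not importing Pallas):

```python
# Standalone sketch of default_index_map: ignore the grid indices and return
# one zero per block dimension, i.e. always select the first block.
def default_index_map(ndim):
  return lambda *args: (0,) * ndim

index_map = default_index_map(2)
assert index_map(3, 7) == (0, 0)   # any grid position -> block (0, 0)
```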
4 changes: 2 additions & 2 deletions jax/_src/pallas/hlo_interpreter.py
@@ -336,7 +336,6 @@ def pallas_call_hlo_interpret(
*args,
backend: str | None,
jaxpr: jax_core.Jaxpr,
name_and_src_info: pallas_core.NameAndStrInfo,
debug: bool,
input_output_aliases: tuple[tuple[int, int], ...],
grid_mapping: GridMapping,
@@ -345,6 +344,7 @@
out_avals: tuple[jax_core.AbstractValue, ...],
):
del compiler_params, cost_estimate, out_avals
debug_info = jaxpr.debug_info
# If we're in interpret mode, we *scan* over the grid and eval the
# discharged jaxpr.
dynamic_grid_args, args = split_list(
@@ -360,7 +360,7 @@
discharged_jaxpr, discharged_consts, scratch_avals = kernel_to_hlo_jaxpr(
jaxpr, (), grid_mapping, backend=backend)
if debug:
print(f"\nJaxpr of the the kernel in pallas_call {name_and_src_info}:")
print(f"\nJaxpr of the the kernel in pallas_call {debug_info.func_src_info}:")
print(discharged_jaxpr)
out = _initialize_output_vals(grid_mapping.block_mappings_output,
args, input_output_aliases)
1 change: 0 additions & 1 deletion jax/_src/pallas/mosaic/interpret.py
@@ -796,7 +796,6 @@ def get_interpret_effects():
def interpret_pallas_call(
*args,
jaxpr: jax_core.Jaxpr,
name_and_src_info: pallas_core.NameAndSrcInfo,
debug: bool,
input_output_aliases: tuple[tuple[int, int], ...],
grid_mapping: GridMapping,
11 changes: 5 additions & 6 deletions jax/_src/pallas/mosaic/lowering.py
@@ -500,7 +500,7 @@ class MeshInfo:
def _check_block_mappings(
block_mappings: tuple[pallas_core.BlockMapping, ...],
lowering_context: mlir.LoweringRuleContext,
name_and_src_info: pallas_core.NameAndSrcInfo,
debug_info: jax_core.DebugInfo,
) -> None:
del lowering_context # originally needed for forward compat
for bm in block_mappings:
@@ -514,7 +514,7 @@
continue

def err_details():
return (f"Block spec for {bm.origin} in pallas_call {name_and_src_info} "
return (f"Block spec for {bm.origin} in pallas_call {debug_info.func_src_info} "
"has block shape "
f"{bm.block_shape}, array shape {bm.array_shape_dtype.shape}, "
# TODO(necula): add index_map source location info
Expand Down Expand Up @@ -593,7 +593,6 @@ def lower_jaxpr_to_module(
jaxpr: jax_core.Jaxpr,
*,
dimension_semantics: tuple[str | None, ...] | None,
name_and_src_info: pallas_core.NameAndSrcInfo,
mesh: mesh_lib.Mesh | None = None,
for_verification: bool = False,
dynamic_shape_replacement_enabled: bool = False,
@@ -603,6 +602,7 @@
raise RuntimeError(
"Pallas TPU requires a libTPU version that's at most a month old"
)
debug_info = jaxpr.debug_info
if dynamic_shape_replacement_enabled:
_mosaic_lowering_dynamic_shape_env = LoweringDynamicShapeEnv()

@@ -620,8 +620,7 @@ def dynamic_shape_replacement_fn(
dynamic_shape_replacement_fn = lambda x: x

# Verify that we have legal block mappings to catch errors early.
_check_block_mappings(grid_mapping.block_mappings, lowering_context,
name_and_src_info)
_check_block_mappings(grid_mapping.block_mappings, lowering_context, debug_info)

mosaic_grid_mapping = MosaicGridMapping(
jaxpr,
Expand All @@ -633,7 +632,7 @@ def dynamic_shape_replacement_fn(
mosaic_grid_mapping.maybe_compress_grid()
m = ir.Module.create()
attrs = m.operation.attributes
module_name = name_and_src_info.name
module_name = mlir.sanitize_name(debug_info.func_name)
attrs["sym_name"] = ir.StringAttr.get(module_name)
sym_tab = ir.SymbolTable(m.operation)

13 changes: 6 additions & 7 deletions jax/_src/pallas/mosaic/pallas_call_registration.py
@@ -107,7 +107,6 @@ def pallas_call_tpu_lowering_rule(
ctx: mlir.LoweringRuleContext,
*in_nodes,
jaxpr: jax_core.Jaxpr,
name_and_src_info: core.NameAndSrcInfo,
grid_mapping: core.GridMapping,
input_output_aliases: tuple[tuple[int, int], ...],
debug: bool,
@@ -118,8 +117,9 @@
):
"""Lowers a pallas_call to a Mosaic TPU custom call."""
del interpret
debug_info = jaxpr._debug_info
if debug:
print(f"\nThe kernel jaxpr for pallas_call {name_and_src_info}:")
print(f"\nThe kernel jaxpr for pallas_call {debug_info.func_src_info}:")
print(jaxpr)
if "mosaic" in compiler_params:
mosaic_params = compiler_params["mosaic"]
@@ -149,13 +149,12 @@ def lower_module(for_verification: bool):
dimension_semantics=dimension_semantics,
mesh=mesh,
for_verification=for_verification,
name_and_src_info=name_and_src_info,
dynamic_shape_replacement_enabled=pallas_core.dynamic_shapes_export_enabled(),
)

mosaic_module, extra_args = lower_module(for_verification=False)
if debug:
print(f"\nThe Mosaic module for pallas_call {name_and_src_info}:")
print(f"\nThe Mosaic module for pallas_call {debug_info.func_src_info}:")
print(mosaic_module)
num_extra_args = len(extra_args)
num_dyn_bounds = grid_mapping.num_dynamic_grid_bounds
@@ -176,7 +175,7 @@ def lower_module(for_verification: bool):
verification_module, num_devices, num_cores
)
if promela_dump_path == "stdout":
print(f"The Promela model for pallas_call {name_and_src_info}:")
print(f"The Promela model for pallas_call {debug_info.func_src_info}:")
print(model)
else:
if promela_dump_path == "sponge":
@@ -188,7 +187,7 @@
)
dump_ctx = tempfile.NamedTemporaryFile(
mode="w",
prefix=name_and_src_info.name + "-",
prefix=mlir.sanitize_name(debug_info.func_name) + "-",
suffix=".pml",
dir=promela_dump_path, delete=False,
)
Expand Down Expand Up @@ -230,7 +229,7 @@ def _maybe_cast_inputs(*args):
module=mosaic_module,
out_type=kernel_out_avals,
backend="tpu",
kernel_name=name_and_src_info.name,
kernel_name=mlir.sanitize_name(debug_info.func_name),
cost_estimate=mosaic_cost_estimate,
vmem_limit_bytes=mosaic_params.get("vmem_limit_bytes"),
flags=mosaic_params.get("flags"),
14 changes: 6 additions & 8 deletions jax/_src/pallas/mosaic_gpu/lowering.py
@@ -360,11 +360,11 @@ def _eval_index_map(

def _check_block_mappings(
block_mappings: Sequence[pallas_core.BlockMapping],
name_and_src_info: pallas_core.NameAndSrcInfo,
debug_info: jax_core.DebugInfo,
) -> None:
def err_details(bm: pallas_core.BlockMapping) -> str:
return (
f"Block spec for {bm.origin} in pallas_call {name_and_src_info}"
f"Block spec for {bm.origin} in pallas_call {debug_info.func_src_info}"
f" has block shape {bm.block_shape}, array shape"
f" {bm.array_shape_dtype.shape},"
# TODO(necula): add index_map source location info
@@ -435,7 +435,6 @@ def index_map(*indices):
def lower_pipelined_jaxpr_to_module(
grid_mapping: pallas_core.GridMapping,
jaxpr: jax_core.Jaxpr,
name_and_src_info: pallas_core.NameAndSrcInfo,
compiler_params: dict[str, Any],
cost_estimate: pallas_core.CostEstimate | None,
) -> LoweringResult:
@@ -453,7 +452,7 @@
)

block_mappings = grid_mapping.block_mappings
_check_block_mappings(block_mappings, name_and_src_info)
_check_block_mappings(block_mappings, jaxpr.debug_info)
in_block_mappings, out_block_mappings = util.split_list(
block_mappings, [grid_mapping.num_inputs]
)
@@ -554,7 +553,6 @@ def body_fn(*refs):
[bm.array_shape_dtype for bm in in_block_mappings],
[bm.array_shape_dtype for bm in out_block_mappings],
new_jaxpr,
name_and_src_info,
compiler_params,
new_consts,
)
@@ -567,10 +565,10 @@ def lower_jaxpr_to_module(
in_shapes: Sequence[jax.ShapeDtypeStruct],
out_shapes: Sequence[jax.ShapeDtypeStruct],
jaxpr: jax_core.Jaxpr,
name_and_src_info: pallas_core.NameAndSrcInfo,
compiler_params: dict[str, Any],
consts=(),
) -> LoweringResult:
debug_info = jaxpr.debug_info
params = compiler_params.get("mosaic_gpu", {})
approx_math = params.get("approx_math", False)
thread_semantics = params.get(
@@ -593,7 +591,7 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value):
for barrier, barrier_ref in zip(rs.barriers, runtime_barriers):
grouped_barriers[barrier].append(barrier_ref)
module_ctx = ModuleContext(
name_and_src_info.name,
mlir.sanitize_name(debug_info.func_name),
grid_names,
[_program_id(axis, squashed_dims) for axis in range(len(grid))],
approx_math,
@@ -632,7 +630,7 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value):
jax.ShapeDtypeStruct(shape=[smem_scratch_bytes], dtype=np.int8),
rs.barriers,
),
module_name=name_and_src_info.name,
module_name=mlir.sanitize_name(debug_info.func_name),
prof_spec=prof_spec,
)
)