Coding style: asserting shapes

The codebase follows a consistent pattern of asserting tensor shapes at every stage of computation.

Core rules

  1. Define expected shape upfront. Before any computation, derive the expected shape from the lattice shape, batch dims, and other known sizes:

    lat_shape = u.shape[:4]
    n_fields = u.shape[-1]
    per_field_shape = (*lat_shape, 3, 3, n_fields)
    full_out_shape = (*lat_shape, 4, 3, 3, m)
    
  2. Always assert the whole shape, not slices of it. Compare the full .shape tuple against the full expected tuple. Prefer assert u.shape == full_shape over assert u.shape[4:] == (4, 3, 3) — a slice assertion can miss a silently changed lattice dimension.

  3. Only fix existing intermediate assertions. Do not add new assertions on intermediate results; maintain and fix the intermediate shape assertions that are already in the code.

  4. Assert function input shapes at entry. Validate parameters and arguments before any computation, building the expected shapes from lat_shape, n_fields, or hard-coded constants:

    expected_u_shape = (*lat_shape, 4, 3, 3)
    assert u.shape == expected_u_shape, f"func: u.shape={u.shape} != {expected_u_shape}"
    assert params["conv_kernel"].shape == (n_offsets, 24, 4), ...
    assert target.shape == lat_shape, ...
    
  5. Assert final output shape before return. The last thing before returning is a shape check on the result:

    result = jnp.stack(dir_fields, axis=4)
    assert result.shape == full_out_shape, ...
    return result
    
  6. Use descriptive error messages of the form f"func_name: var.shape={var.shape} != {expected_shape}", so the failure reports both the actual and the expected shape.
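The rules above can be combined in a single function. The sketch below is hypothetical and assumes jax is installed. The name transport_fields, its einsum contraction and the (*lat_shape, 4, 3, 3) link layout are illustrative choices that match the snippets above; they are not taken from this codebase.

```python
import jax
import jax.numpy as jnp


def transport_fields(u, fields):
    """Hypothetical example following the shape-assertion rules.

    u:      (*lat_shape, 4, 3, 3)         one 3x3 matrix per site and direction
    fields: (*lat_shape, 3, 3, n_fields)  n_fields 3x3 matrices per site
    returns (*lat_shape, 4, 3, 3, n_fields)
    """
    # Rule 1: derive every expected shape before any computation.
    lat_shape = u.shape[:4]
    n_fields = fields.shape[-1]
    expected_u_shape = (*lat_shape, 4, 3, 3)
    per_field_shape = (*lat_shape, 3, 3, n_fields)
    full_out_shape = (*lat_shape, 4, 3, 3, n_fields)

    # Rules 2, 4 and 6: whole-shape checks at entry, with descriptive messages.
    assert u.shape == expected_u_shape, (
        f"transport_fields: u.shape={u.shape} != {expected_u_shape}"
    )
    assert fields.shape == per_field_shape, (
        f"transport_fields: fields.shape={fields.shape} != {per_field_shape}"
    )

    # Multiply each direction's 3x3 matrix into every field (rule 3: no new intermediate asserts).
    dir_fields = [
        jnp.einsum("...ab,...bcn->...acn", u[..., mu, :, :], fields)
        for mu in range(4)
    ]
    result = jnp.stack(dir_fields, axis=4)

    # Rule 5: the last step before returning is a whole-shape check on the result.
    assert result.shape == full_out_shape, (
        f"transport_fields: result.shape={result.shape} != {full_out_shape}"
    )
    return result


# Usage: identity links leave every field unchanged in each direction.
u = jnp.broadcast_to(jnp.eye(3), (2, 2, 2, 2, 4, 3, 3))
fields = jnp.arange(720, dtype=jnp.float32).reshape(2, 2, 2, 2, 3, 3, 5)
out = jax.jit(transport_fields)(u, fields)
print(out.shape)  # (2, 2, 2, 2, 4, 3, 3, 5)

# A field whose last lattice dimension differs is rejected at entry. A slice
# check such as bad.shape[4:] == (3, 3, 5) would have passed for this input.
bad = jnp.zeros((2, 2, 2, 1, 3, 3, 5))
try:
    transport_fields(u, bad)
except AssertionError as err:
    print(err)  # transport_fields: fields.shape=(2, 2, 2, 1, 3, 3, 5) != (2, 2, 2, 2, 3, 3, 5)
```

The same function runs eagerly and under jax.jit. The asserts compare plain Python tuples in both cases.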

Rationale

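  • Catches shape errors when a function is traced under jax.jit, before any compiled code runs. Without the check, such errors can silently produce wrong results (for example through broadcasting).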

  • Makes the expected tensor contract self-documenting — the expected shape is written right next to the computation.

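  • Array shapes are static under jax.jit, so these asserts run in Python while the function is traced and never become part of the compiled computation. Jitted code therefore pays no runtime cost for them. Only eager (un-jitted) calls evaluate them on every call.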
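A minimal sketch of why the asserts are free in jitted code, assuming jax is installed. It uses a hypothetical scale_links function. A Python-side log shows the body (and its asserts) runs only when jax.jit traces. jax.make_jaxpr shows what was staged: the multiplication is there, but nothing from the asserts.

```python
import jax
import jax.numpy as jnp

trace_log = []  # records each time the Python body of scale_links actually runs


def scale_links(u):
    """Hypothetical function: doubles a (*lat_shape, 4, 3, 3) array of link matrices."""
    trace_log.append(u.shape)  # Python side effect: under jit, happens only while tracing
    expected = (*u.shape[:4], 4, 3, 3)
    assert u.shape == expected, f"scale_links: u.shape={u.shape} != {expected}"
    out = 2.0 * u
    assert out.shape == expected, f"scale_links: out.shape={out.shape} != {expected}"
    return out


scale_links_jit = jax.jit(scale_links)
u = jnp.ones((2, 2, 2, 2, 4, 3, 3))
scale_links_jit(u)
scale_links_jit(u)  # same shape and dtype: cached compiled code, Python body not re-run
traces_after_two_calls = len(trace_log)
print(traces_after_two_calls)  # 1

# The staged program contains the multiplication; the Python asserts add no equations.
print(jax.make_jaxpr(scale_links)(u))

# A wrong shape fails while tracing, before anything is compiled or executed.
try:
    scale_links_jit(jnp.ones((2, 2, 2, 2, 3, 3, 3)))
except AssertionError as err:
    print(err)  # scale_links: u.shape=(2, 2, 2, 2, 3, 3, 3) != (2, 2, 2, 2, 4, 3, 3)
```

This works only because shapes are static Python tuples. An assert on array values inside a jitted function would instead fail with a tracer conversion error.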