mstar.utils.adarms_norm

Fused AdaRMS normalisation + scale/shift/gate Triton kernel.

Replaces the three-step sequence in Pi05AdaRMSNorm.forward():

  1. _rms_normalize(x) – two passes over x (cast, square, mean, rsqrt, mul)

  2. modulation.chunk(3, dim=-1) – returns views of modulation, so effectively free

  3. normed * (1+scale) + shift – two more passes over x-sized tensors

with a single pass:

  • Load x row → compute RMS in float32 → normalise

  • Load (scale, shift, gate) row from modulation → apply conditioning

  • Store normed output and gate

Falls back to the original eager path on CPU or when Triton is not available.
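The arithmetic described above can be sketched for a single row without any dependencies. This is an illustration only, not the Triton kernel or the library's eager path: plain Python lists stand in for tensors, and the helper names (rms_normalize, eager_three_step, fused_single_pass) are hypothetical. Two assumptions are made: eps is added to the mean square before the reciprocal square root, and the modulation row is laid out as [scale | shift | gate], matching the order produced by chunk(3, dim=-1).

```python
import math


def rms_normalize(row, eps=1e-6):
    """Step 1 of the eager path: scale the row by 1 / sqrt(mean(x^2) + eps)."""
    mean_sq = sum(v * v for v in row) / len(row)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)  # ASSUMPTION: eps inside the sqrt
    return [v * inv_rms for v in row]


def eager_three_step(row, modulation, eps=1e-6):
    """Three separate steps; each one materialises a full intermediate row."""
    h = len(row)
    normed = rms_normalize(row, eps)                                   # step 1
    # ASSUMPTION: modulation is laid out as [scale | shift | gate].
    scale, shift, gate = modulation[:h], modulation[h:2 * h], modulation[2 * h:]  # step 2
    out = [n * (1.0 + s) + t for n, s, t in zip(normed, scale, shift)]  # step 3
    return out, gate


def fused_single_pass(row, modulation, eps=1e-6):
    """Same arithmetic in one sweep: reduce once, then write out and gate together."""
    h = len(row)
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in row) / h + eps)
    out, gate = [], []
    for j in range(h):
        # Normalise, scale and shift element j without storing the normed row.
        out.append(row[j] * inv_rms * (1.0 + modulation[j]) + modulation[h + j])
        gate.append(modulation[2 * h + j])
    return out, gate
```

Calling both functions on the same row and modulation values should give matching results. The only difference is that the fused variant never stores the intermediate normalised row, which is the memory traffic the single-pass design is meant to avoid.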

Functions

adarms_norm_fused(x, scale, shift, gate_mod)

Fused AdaRMS norm: RMS-normalise x then apply scale/shift conditioning.

mstar.utils.adarms_norm.adarms_norm_fused(x, scale, shift, gate_mod, eps=1e-6)

Fused AdaRMS norm: RMS-normalise x then apply scale/shift conditioning.

Parameters:
  • x (Tensor) – input tensor of shape [BS * AH, H], in any floating-point dtype.

  • scale (Tensor) – shape [BS, H]; the normalised input is multiplied by (1 + scale).

  • shift (Tensor) – shape [BS, H]; additive shift applied after normalisation and scaling.

  • gate_mod (Tensor) – shape [BS, H]; gate vector, passed through with its values unchanged.

  • eps (float) – variance epsilon for numerical stability.

Returns:

(normed, gate), both of shape [BS * AH, H] and the same dtype as x.

Return type:

tuple[Tensor, Tensor]

Falls back to an eager implementation on CPU or when Triton is not available.
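The shape contract above can be illustrated with a dependency-free stand-in that mirrors the documented signature. This is a sketch under stated assumptions, not the library's fallback: adarms_norm_reference is a hypothetical name, nested lists stand in for tensors, and row i of x is assumed to use modulation row i // AH (rows grouped batch-major). The source gives the shapes but does not confirm that grouping.

```python
import math


def adarms_norm_reference(x, scale, shift, gate_mod, eps=1e-6):
    """Hypothetical pure-Python stand-in for the documented signature.

    x:                      BS * AH rows of length H
    scale, shift, gate_mod: BS rows of length H
    Returns (normed, gate), each with BS * AH rows of length H.
    """
    rows, bs = len(x), len(scale)
    if bs == 0 or rows % bs != 0:
        raise ValueError("len(x) must be a positive multiple of len(scale)")
    ah = rows // bs  # number of rows of x that share one modulation row

    normed, gate = [], []
    for i, row in enumerate(x):
        b = i // ah  # ASSUMPTION: rows of x are grouped batch-major
        # ASSUMPTION: eps is added to the mean square before the square root.
        inv_rms = 1.0 / math.sqrt(sum(v * v for v in row) / len(row) + eps)
        normed.append(
            [v * inv_rms * (1.0 + s) + t for v, s, t in zip(row, scale[b], shift[b])]
        )
        gate.append(list(gate_mod[b]))  # gate values are copied, not modified
    return normed, gate


# Example with BS = 2, AH = 2, H = 2.
x = [[1.0, 2.0], [3.0, 4.0], [0.5, -0.5], [2.0, 0.0]]
scale = [[0.0, 0.0], [1.0, 1.0]]
shift = [[0.0, 0.0], [0.5, 0.5]]
gate_mod = [[0.1, 0.2], [0.3, 0.4]]
normed, gate = adarms_norm_reference(x, scale, shift, gate_mod)
```

In this sketch both outputs have one row per row of x, so each gate_mod row is repeated AH times. That repetition is how the [BS, H] gate input becomes a [BS * AH, H] output.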