mstar.model.vjepa2.components.layers
Shared transformer building blocks for V-JEPA 2 encoder and predictor.
Ports VJEPA2MLP, VJEPA2RopeAttention, VJEPA2Layer from HuggingFace
transformers/models/vjepa2/modeling_vjepa2.py. Uses eager (matmul+softmax)
attention — no SDPA auto-selection or gradient checkpointing — to keep
numerics bit-reproducible against the reference implementation.
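To make "eager (matmul+softmax)" concrete, here is a minimal single-head sketch in plain Python. It is framework-free so it can run anywhere; the helper names `softmax` and `eager_attention` are illustrative and are not part of this module, which operates on PyTorch tensors. The point is only the order of operations: explicit score matrix, explicit softmax, explicit weighted sum, with no fused kernel choosing a different reduction order.

```python
import math


def softmax(row):
    """Numerically stable softmax over one row of scores."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def eager_attention(q, k, v):
    """Single-head attention over lists of token vectors.

    q, k: lists of vectors with equal dimension; v: list of value vectors.
    Returns (outputs, attention_weights).
    """
    scale = 1.0 / math.sqrt(len(q[0]))
    # Step 1 (matmul): scores[i][j] = scaled dot product of query i and key j.
    scores = [
        [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        for qi in q
    ]
    # Step 2 (softmax): normalize each query's scores over all keys.
    weights = [softmax(row) for row in scores]
    # Step 3 (matmul): each output is the weighted sum of the value vectors.
    dim_v = len(v[0])
    outputs = [
        [sum(w * vj[d] for w, vj in zip(w_row, v)) for d in range(dim_v)]
        for w_row in weights
    ]
    return outputs, weights
```

With an all-zero query, every key receives the same score, so the output is simply the mean of the value vectors.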
Weight layout per layer (matches HF checkpoint keys):
- norm1.{weight,bias}
- attention.{query,key,value}.{weight,bias}
- attention.proj.{weight,bias}
- norm2.{weight,bias}
- mlp.fc1.{weight,bias}
- mlp.fc2.{weight,bias}
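The brace patterns in the weight layout expand to sixteen parameter names per layer. The sketch below expands them and reports which ones are absent from a set of checkpoint keys. The function names and the `prefix` argument are hypothetical helpers for illustration; the actual prefix used in a given checkpoint is not stated here.

```python
from itertools import product

# Sub-module names per layer, taken from the weight layout listed above.
SUBMODULES = [
    "norm1",
    "attention.query",
    "attention.key",
    "attention.value",
    "attention.proj",
    "norm2",
    "mlp.fc1",
    "mlp.fc2",
]
PARAMS = ["weight", "bias"]


def expected_layer_keys(prefix=""):
    """Expand the layout into full key names; `prefix` is a hypothetical layer prefix."""
    return [f"{prefix}{module}.{param}" for module, param in product(SUBMODULES, PARAMS)]


def missing_keys(state_dict_keys, prefix=""):
    """Return the expected keys that do not appear in `state_dict_keys`."""
    present = set(state_dict_keys)
    return [key for key in expected_layer_keys(prefix) if key not in present]
```

Running `missing_keys` on the keys of a loaded state dict before copying weights is a cheap way to catch a layout mismatch early.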
Classes
| VJEPA2Layer | One transformer block: pre-norm self-attention + pre-norm MLP with residuals. |
| VJEPA2MLP | |
| VJEPA2RopeAttention | Self-attention with 3D rotary positional encoding. |
- class mstar.model.vjepa2.components.layers.VJEPA2Layer(config, hidden_size, num_attention_heads, mlp_ratio)
Bases: Module
One transformer block: pre-norm self-attention + pre-norm MLP with residuals.
- Parameters:
config (VJepa2Config)
hidden_size (int)
num_attention_heads (int)
mlp_ratio (float)
- forward(hidden_states, position_mask=None)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
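The "pre-norm ... with residuals" structure described for VJEPA2Layer can be sketched without any framework. In the snippet below, `norm1`, `attention`, `norm2`, and `mlp` are arbitrary callables standing in for the real sub-modules, and `pre_norm_block` is an illustrative name, not the module's code. What it shows is the ordering: each sub-layer sees a normalized input, and its output is added back to the un-normalized stream.

```python
def add(a, b):
    """Element-wise sum of two equal-length vectors."""
    return [x + y for x, y in zip(a, b)]


def pre_norm_block(x, norm1, attention, norm2, mlp):
    """Apply one pre-norm transformer block to vector `x`.

    All four callables map a list of floats to a list of the same length.
    """
    # Residual branch 1: normalize, attend, then add to the original input.
    x = add(x, attention(norm1(x)))
    # Residual branch 2: normalize the updated stream, apply the MLP, add again.
    x = add(x, mlp(norm2(x)))
    return x
```

If both sub-layers return zeros, the block reduces to the identity, which is the property that makes residual stacks easy to initialize and debug.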
- class mstar.model.vjepa2.components.layers.VJEPA2MLP(config, hidden_size, mlp_ratio)
Bases: Module
- Parameters:
config (VJepa2Config)
hidden_size (int)
mlp_ratio (float)
- forward(hidden_state)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
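The page does not spell out how `mlp_ratio` is used. A common convention in ViT-style blocks, assumed here rather than confirmed by this page, is that the inner width is `int(hidden_size * mlp_ratio)`, with `fc1` expanding to that width and `fc2` projecting back. Under that assumption, the sketch below computes the expected `fc1`/`fc2` parameter shapes, following the `nn.Linear` convention that a weight has shape `(out_features, in_features)`. `mlp_shapes` is a hypothetical helper.

```python
def mlp_shapes(hidden_size, mlp_ratio):
    """Expected parameter shapes for a two-layer MLP.

    Assumption: inner width = int(hidden_size * mlp_ratio).
    """
    inner = int(hidden_size * mlp_ratio)
    return {
        "fc1.weight": (inner, hidden_size),  # expand: hidden -> inner
        "fc1.bias": (inner,),
        "fc2.weight": (hidden_size, inner),  # project back: inner -> hidden
        "fc2.bias": (hidden_size,),
    }


def mlp_param_count(hidden_size, mlp_ratio):
    """Total number of scalars across the four tensors above."""
    inner = int(hidden_size * mlp_ratio)
    return inner * hidden_size + inner + hidden_size * inner + hidden_size
```

Comparing these shapes against a checkpoint's `mlp.fc1.*` and `mlp.fc2.*` tensors is a quick way to test whether the assumed convention holds for a given model.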
- class mstar.model.vjepa2.components.layers.VJEPA2RopeAttention(config, hidden_size, num_attention_heads)
Bases: Module
Self-attention with 3D rotary positional encoding.
Q/K/V are separate nn.Linear projections (matches HF checkpoint key layout: attention.{query,key,value,proj}.*). RoPE is applied to queries and keys, split into depth/height/width axes derived from each token’s position id.
- Parameters:
config (VJepa2Config)
hidden_size (int)
num_attention_heads (int)
- forward(hidden_states, position_mask=None)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
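As a rough illustration of the 3D rotary scheme described for VJEPA2RopeAttention, the sketch below decomposes a flat token position id into depth/height/width coordinates and rotates one chunk of a query or key vector per axis. Several details are assumptions, not taken from this page: the row-major token order (width varying fastest), the adjacent-pair rotation convention, the `base` of 10000, and the way the head dimension is divided into three equal chunks with any remainder left unrotated. All function names are hypothetical.

```python
import math


def token_to_dhw(token_id, grid_h, grid_w):
    """Flat token id -> (depth, height, width), assuming width varies fastest."""
    depth, rest = divmod(token_id, grid_h * grid_w)
    height, width = divmod(rest, grid_w)
    return depth, height, width


def rotate_chunk(chunk, pos, base=10000.0):
    """Standard RoPE on adjacent pairs of an even-length chunk at position `pos`."""
    half = len(chunk) // 2
    out = []
    for i in range(half):
        theta = pos * base ** (-i / half)  # lower pairs rotate faster
        c, s = math.cos(theta), math.sin(theta)
        x0, x1 = chunk[2 * i], chunk[2 * i + 1]
        out += [x0 * c - x1 * s, x0 * s + x1 * c]
    return out


def rope_3d(vec, dhw, axis_dim):
    """Rotate three consecutive chunks of `vec` by depth, height, and width.

    `axis_dim` must be even; elements past 3 * axis_dim are left unchanged.
    """
    out = []
    for axis, pos in enumerate(dhw):
        start = axis * axis_dim
        out += rotate_chunk(vec[start:start + axis_dim], pos)
    out += vec[3 * axis_dim:]
    return out
```

Because each pair undergoes a pure rotation, the vector's norm is preserved, and a token at position (0, 0, 0) is left unchanged.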