mstar.model.vjepa2.components.vit_encoder

V-JEPA 2 video encoder (ViT with 3D tubelet patches + 3D RoPE).

Ports VJEPA2PatchEmbeddings3D, VJEPA2Embeddings, VJEPA2Encoder from HuggingFace transformers/models/vjepa2/modeling_vjepa2.py.

Input layout: pixel_values_videos has shape [B, T, C, H, W] (frames before channels, matching the HF default). The embedding layer permutes it internally to [B, C, T, H, W] for Conv3d.
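The following is a minimal shape walk-through of that layout change. It assumes a non-overlapping tubelet Conv3d, meaning the kernel equals the stride. The sizes (tubelet_size=2, patch_size=16, hidden_size=64) are illustrative assumptions, not values read from VJepa2Config. The snippet uses plain torch rather than the mstar classes.

```python
import torch
import torch.nn as nn

# Illustrative sizes (assumptions, not taken from VJepa2Config)
B, T, C, H, W = 2, 4, 3, 32, 32
tubelet_size, patch_size, hidden_size = 2, 16, 64

video = torch.randn(B, T, C, H, W)   # HF layout: frames before channels
x = video.permute(0, 2, 1, 3, 4)     # -> [B, C, T, H, W], the layout Conv3d expects

# Non-overlapping tubelets: kernel and stride are both (tubelet, patch, patch)
proj = nn.Conv3d(
    C,
    hidden_size,
    kernel_size=(tubelet_size, patch_size, patch_size),
    stride=(tubelet_size, patch_size, patch_size),
)
grid = proj(x)                            # [B, hidden, T/t, H/p, W/p] = [2, 64, 2, 2, 2]
tokens = grid.flatten(2).transpose(1, 2)  # [B, N, hidden] with N = 2 * 2 * 2 = 8
```

Note that permute only reorders strides and does not copy the video data.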

Weight layout (matches HF checkpoint keys, prefix encoder.):

  • encoder.embeddings.patch_embeddings.proj.{weight,bias} – Conv3d tubelet projection.

  • encoder.layer.{N}.* – transformer block N.

  • encoder.layernorm.{weight,bias} – final LayerNorm.
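When loading a Hugging Face checkpoint, it can help to first check that its keys follow this layout. The helpers below (strip_prefix, layer_indices) are hypothetical and not part of mstar. The checkpoint values are placeholders. Whether this module's own state_dict expects the encoder. prefix stripped is an assumption you should verify against the loader you actually use.

```python
import re
from collections import OrderedDict


def strip_prefix(state_dict, prefix="encoder."):
    """Hypothetical helper: keep keys under `prefix` and drop the prefix itself."""
    return OrderedDict(
        (key[len(prefix):], value)
        for key, value in state_dict.items()
        if key.startswith(prefix)
    )


def layer_indices(keys):
    """Hypothetical helper: sorted distinct N found in 'layer.{N}.*' keys."""
    pattern = re.compile(r"^layer\.(\d+)\.")
    return sorted({int(m.group(1)) for key in keys if (m := pattern.match(key))})


# Dummy checkpoint: key names follow the layout above, values are placeholders
checkpoint = {
    "encoder.embeddings.patch_embeddings.proj.weight": 0,
    "encoder.embeddings.patch_embeddings.proj.bias": 1,
    "encoder.layer.0.placeholder": 2,
    "encoder.layer.1.placeholder": 3,
    "encoder.layernorm.weight": 4,
    "encoder.layernorm.bias": 5,
    "predictor.placeholder": 6,  # outside the encoder, so it is dropped
}
encoder_sd = strip_prefix(checkpoint)
print(layer_indices(encoder_sd))  # [0, 1]
```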

Functions

apply_masks(tensor, masks)

Gather patches from tensor at per-row indices.

Classes

VJEPA2Embeddings(config, hidden_size)

VJEPA2Encoder(config)

VJEPA2PatchEmbeddings3D(config, hidden_size)

class mstar.model.vjepa2.components.vit_encoder.VJEPA2Embeddings(config, hidden_size)

Bases: Module

Parameters:

  • config (VJepa2Config)

  • hidden_size (int)

forward(pixel_values_videos)

Permute pixel_values_videos from [B, T, C, H, W] to [B, C, T, H, W] and embed it into tubelet tokens with the patch_embeddings Conv3d projection.

Note

Call the module instance rather than forward() directly. Calling the instance runs any registered hooks, whereas calling forward() silently skips them.

Parameters:

pixel_values_videos (Tensor) – video of shape [B, T, C, H, W].

Return type:

Tensor
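The number of tokens produced by the embedding layer follows from the tubelet grid. The arithmetic sketch below assumes non-overlapping tubelets, with the Conv3d kernel equal to the stride and no padding. Under that assumption, each axis contributes floor(size / step) positions. The helper name num_tubelet_tokens is hypothetical. The frame count, resolution, tubelet_size, and patch_size are illustrative assumptions.

```python
def num_tubelet_tokens(num_frames, height, width, tubelet_size, patch_size):
    """Tokens from a Conv3d with kernel == stride == (tubelet, patch, patch), no padding.

    For size >= k, each axis yields floor((size - k) / k) + 1 == size // k
    positions, so any remainder frames or pixels are dropped.
    """
    t_grid = num_frames // tubelet_size
    h_grid = height // patch_size
    w_grid = width // patch_size
    return t_grid * h_grid * w_grid


# 64 frames at 256 x 256 with 2-frame tubelets and 16 x 16 patches:
# (64 / 2) * (256 / 16) * (256 / 16) = 32 * 16 * 16
print(num_tubelet_tokens(64, 256, 256, tubelet_size=2, patch_size=16))  # 8192
```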

class mstar.model.vjepa2.components.vit_encoder.VJEPA2Encoder(config)

Bases: Module

Parameters:

config (VJepa2Config)

forward(pixel_values_videos)

Embed the video with encoder.embeddings, pass the tokens through the stacked encoder.layer.{N} blocks, and apply the final encoder.layernorm.

Parameters:

pixel_values_videos (Tensor) – video of shape [B, T, C, H, W].

Return type:

Tensor
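The module summary describes the encoder as using 3D RoPE. The sketch below shows one common axis-factored formulation. It is not the mstar code and rests on several assumptions:

- Tokens are flattened frame-major (t, then h, then w), which is the order flatten(2) produces on a Conv3d output of shape [B, D, T', H', W'].
- Each head's channels are split into three equal, even-sized chunks, rotated by the t, h, and w coordinates respectively.
- Any leftover channels pass through unchanged.
- Rotations use interleaved channel pairs with base 10000.

All function names are hypothetical.

```python
import torch


def grid_coords(t_grid: int, h_grid: int, w_grid: int):
    """(t, h, w) coordinates of each token, assuming frame-major (t, h, w) flattening."""
    idx = torch.arange(t_grid * h_grid * w_grid)
    t = idx // (h_grid * w_grid)
    h = (idx // w_grid) % h_grid
    w = idx % w_grid
    return t, h, w


def rope_1d(x: torch.Tensor, pos: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    """Rotate interleaved channel pairs (0, 1), (2, 3), ... of x [..., N, d] by pos-dependent angles."""
    d = x.shape[-1]
    inv_freq = 1.0 / (base ** (torch.arange(0, d, 2, dtype=torch.float32) / d))  # [d/2]
    angles = pos.to(torch.float32)[:, None] * inv_freq[None, :]                  # [N, d/2]
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out


def rope_3d(x: torch.Tensor, t_grid: int, h_grid: int, w_grid: int) -> torch.Tensor:
    """Apply 1D RoPE to three chunks of the head dim, one per axis; leftover channels pass through."""
    part = 2 * ((x.shape[-1] // 3) // 2)  # even chunk size per axis
    t, h, w = grid_coords(t_grid, h_grid, w_grid)
    chunks = [
        rope_1d(x[..., :part], t),
        rope_1d(x[..., part:2 * part], h),
        rope_1d(x[..., 2 * part:3 * part], w),
        x[..., 3 * part:],
    ]
    return torch.cat(chunks, dim=-1)


q = torch.randn(2, 4, 8, 16)  # [B, heads, N = 2 * 2 * 2 tokens, head_dim]
q_rot = rope_3d(q, t_grid=2, h_grid=2, w_grid=2)
```

Each chunk is rotated by its own coordinate. As a result, the attention score between two tokens depends on their relative offsets along each axis rather than on their absolute positions.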

class mstar.model.vjepa2.components.vit_encoder.VJEPA2PatchEmbeddings3D(config, hidden_size)

Bases: Module

Parameters:

  • config (VJepa2Config)

  • hidden_size (int)

forward(pixel_values_videos)

Apply the proj Conv3d to a channels-first video of shape [B, C, T, H, W] (the layout VJEPA2Embeddings passes in) and flatten the resulting tubelet grid into a token sequence.

Parameters:

pixel_values_videos (Tensor) – channels-first video of shape [B, C, T, H, W].

Return type:

Tensor
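To show what a tubelet patch is, the sketch below checks numerically that a Conv3d whose kernel equals its stride is equivalent to two steps: cutting the video into non-overlapping t x p x p tubelets, then applying one shared linear map to each flattened tubelet. It uses plain torch with small assumed sizes and does not call the mstar class.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
C, t, p, D = 3, 2, 4, 8  # assumed channels, tubelet_size, patch_size, hidden_size
proj = nn.Conv3d(C, D, kernel_size=(t, p, p), stride=(t, p, p))

x = torch.randn(1, C, 4, 8, 8)  # channels-first video [B, C, T, H, W]
with torch.no_grad():
    conv_tokens = proj(x).flatten(2).transpose(1, 2)  # [B, N, D], N = 2 * 2 * 2

    # Cut the same video into non-overlapping tubelets of shape (C, t, p, p)
    B, _, T, H, W = x.shape
    tubes = x.reshape(B, C, T // t, t, H // p, p, W // p, p)
    # Order tokens as (T', H', W') to match flatten(2); keep (C, t, p, p) inside each token
    tubes = tubes.permute(0, 2, 4, 6, 1, 3, 5, 7).reshape(B, -1, C * t * p * p)
    # Shared linear map: the Conv3d weight [D, C, t, p, p] flattened to [D, C*t*p*p]
    linear_tokens = tubes @ proj.weight.reshape(D, -1).T + proj.bias

print(conv_tokens.shape, torch.allclose(conv_tokens, linear_tokens, atol=1e-5))
```

This equivalence is why a tubelet embedding is often described as a linear projection of flattened tubelets.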

mstar.model.vjepa2.components.vit_encoder.apply_masks(tensor, masks)

Gather patches from tensor at per-row indices, once per mask.

Parameters:
  • tensor (Tensor) – [B, N, D].

  • masks (list[Tensor]) – list of [B, M] index tensors. The gathered results are concatenated along dim 0.

Returns:

[len(masks) * B, M, D].

Return type:

Tensor
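Below is a minimal sketch of the gather-then-concatenate behaviour described above, written with torch.gather. It is named apply_masks_sketch to make clear that it illustrates this reading of the docstring and is not the mstar source.

```python
import torch


def apply_masks_sketch(tensor, masks):
    """tensor: [B, N, D]; masks: list of [B, M] int64 index tensors -> [len(masks) * B, M, D]."""
    gathered = []
    for mask in masks:
        # Broadcast each row's patch indices across the feature dim: [B, M] -> [B, M, D]
        index = mask.to(tensor.device).unsqueeze(-1).repeat(1, 1, tensor.size(-1))
        gathered.append(torch.gather(tensor, dim=1, index=index))  # [B, M, D]
    # Concatenate (not stack) so the masks extend the batch dimension
    return torch.cat(gathered, dim=0)


tokens = torch.arange(2 * 5 * 3, dtype=torch.float32).reshape(2, 5, 3)  # B=2, N=5, D=3
masks = [torch.tensor([[0, 2], [1, 4]]), torch.tensor([[3, 3], [0, 1]])]
out = apply_masks_sketch(tokens, masks)
print(out.shape)  # torch.Size([4, 2, 3])
```

Rows 0 to B-1 of the output come from the first mask, rows B to 2B-1 from the second, and so on.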