Models Reference
Installation

```shell
pip install soma-models
```

Note that PyTorch (torch) and/or Flax (flax, jax) are optional dependencies — install whichever framework you need.

```shell
pip install soma-models torch     # PyTorch
pip install soma-models flax jax  # Flax/JAX
```

Configuration
Five dataclass configs control the model architecture. All have sensible defaults matching the V1 architecture.
ModelConfig
Top-level model configuration.
```python
from soma_models.v1.configs import ModelConfig

config = ModelConfig(dropout_rate=0.1)
```

| Field | Type | Default | Description |
|---|---|---|---|
dropout_rate | float | — | Dropout probability (required) |
embedding_dim | int | 2048 | Token embedding dimension |
pwff_hidden_dim | int | 8192 | Feed-forward hidden dimension |
num_layers | int | 24 | Number of encoder layers |
num_heads | int | 8 | Number of attention heads |
vocab_size | int | 264 | Vocabulary size (256 bytes + 8 special tokens) |
max_wavelength | float | 10000.0 | RoPE max wavelength |
scale_factor | float | 1.0 | RoPE scale factor |
EncoderConfig
| Field | Type | Default | Description |
|---|---|---|---|
dropout_rate | float | — | Dropout probability (required) |
embedding_dim | int | 2048 | Embedding dimension |
pwff_hidden_dim | int | 8192 | Feed-forward hidden dimension |
num_layers | int | 24 | Number of layers |
num_heads | int | 8 | Attention heads |
max_wavelength | float | 10000.0 | RoPE max wavelength |
scale_factor | float | 1.0 | RoPE scale factor |
LayerConfig
| Field | Type | Default | Description |
|---|---|---|---|
dropout_rate | float | — | Dropout probability (required) |
embedding_dim | int | 2048 | Embedding dimension |
pwff_hidden_dim | int | 8192 | Feed-forward hidden dimension |
num_heads | int | 8 | Attention heads |
max_wavelength | float | 10000.0 | RoPE max wavelength |
scale_factor | float | 1.0 | RoPE scale factor |
PositionWiseFeedForwardConfig
| Field | Type | Default | Description |
|---|---|---|---|
dropout_rate | float | — | Dropout probability (required) |
embedding_dim | int | 2048 | Input/output dimension |
pwff_hidden_dim | int | 8192 | Hidden dimension |
SIGRegConfig
Spectral Isotropy Gaussian Regularization config.
| Field | Type | Default | Description |
|---|---|---|---|
t_max | float | 3.0 | Maximum integration bound |
points | int | 17 | Number of quadrature points |
slices | int | 256 | Number of random projections |
coefficient | float | 0.02 | Regularization weight |
Constants
| Constant | Value | Description |
|---|---|---|
V1_EMBEDDING_DIM | 2048 | Token embedding dimension |
V1_NUM_HEADS | 8 | Attention heads |
V1_NUM_LAYERS | 24 | Encoder layers |
V1_MAX_SEQ_LEN | 1024 | Maximum sequence length |
V1_VOCAB_SIZE | 264 | 256 byte values + 8 special tokens |
V1_PWFF_HIDDEN_DIM | 8192 | Feed-forward hidden dimension |
V1_MAX_WAVELENGTH | 10000.0 | RoPE max wavelength |
V1_SCALE_FACTOR | 1.0 | RoPE scale factor |
V1_PAD_TOKEN_ID | 256 | Padding token ID |
V1_EOS_TOKEN_ID | 257 | End-of-sequence token ID |
V1_SIG_REG_T_MAX | 3.0 | SIGReg integration bound |
V1_SIG_REG_SLICES | 256 | SIGReg random projections |
V1_SIG_REG_POINTS | 17 | SIGReg quadrature points |
V1_SIG_REG_COEFFICIENT | 0.02 | SIGReg weight |
V1_BATCH_SIZE | 16 | Default batch size |
ARCHITECTURE_VERSION | 1 | Must match on-chain model_architecture_version |
Tokenizer
Framework-agnostic byte-level tokenizer matching the on-chain V1 data contract.
tokenize
Convert raw bytes into batches of token IDs, targets, and position IDs.
```python
from soma_models.v1.tokenizer import tokenize

batches = tokenize(data, max_seq_len=1024, batch_size=16)
```

| Parameter | Type | Required | Default | Description |
|---|---|---|---|---|
data | bytes \| bytearray | Yes | — | Raw byte data to tokenize |
max_seq_len | int | No | 1024 | Maximum sequence length per chunk |
batch_size | int | No | 16 | Sequences per batch |
Returns: list[ByteSequenceBatch] — the final batch may contain fewer than batch_size sequences.
ByteSequenceBatch
| Field | Type | Description |
|---|---|---|
token_ids | list[list[int]] | [batch, seq_len] — input token IDs |
targets | list[list[int]] | [batch, seq_len] — next-token targets (shifted left, PAD appended) |
pos_ids | list[list[int]] | [batch, seq_len] — global byte offsets |
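The contract above can be illustrated with a plain-Python sketch; `sketch_tokenize` is a hypothetical re-implementation for illustration only (the real `tokenize` may handle padding and EOS insertion differently):

```python
# Hypothetical sketch of the V1 tokenize data contract, not the library code.
PAD = 256  # V1_PAD_TOKEN_ID

def sketch_tokenize(data: bytes, max_seq_len: int = 1024, batch_size: int = 16):
    sequences = []
    for start in range(0, len(data), max_seq_len):
        chunk = data[start:start + max_seq_len]
        token_ids = list(chunk)                            # byte values 0..255
        targets = token_ids[1:] + [PAD]                    # shifted left, PAD appended
        pos_ids = list(range(start, start + len(chunk)))   # global byte offsets
        sequences.append((token_ids, targets, pos_ids))
    # group into batches; the final batch may hold fewer than batch_size sequences
    return [sequences[i:i + batch_size]
            for i in range(0, len(sequences), batch_size)]

batches = sketch_tokenize(b"hello world", max_seq_len=4, batch_size=2)
```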
Model Components
PyTorch

```python
from soma_models.v1.torch import Model, ModelConfig, SIGReg, SIGRegConfig, compute_loss
```

Model

Inherits Serializable and nn.Module. Full transformer: embedding → encoder → final norm → predictor.

```python
model = Model(ModelConfig(dropout_rate=0.1))
```

| Method | Signature | Description |
|---|---|---|
encode | (input: Tensor, positions: Tensor) -> Tensor | Embed + encode, returns [batch, seq, dim] |
predict | (embeddings: Tensor) -> Tensor | Linear projection to vocab logits |
forward | (input: Tensor, positions: Tensor) -> Tensor | predict(encode(input, positions)) |
Encoder
Stack of Layer modules.
```python
from soma_models.v1.torch.modules.encoder import Encoder
```

| Method | Signature | Description |
|---|---|---|
forward | (input: Tensor, positions: Tensor) -> Tensor | Pass through all layers sequentially |
Layer

Pre-norm transformer layer: LayerNorm → Attention → residual → LayerNorm → FeedForward → residual.

```python
from soma_models.v1.torch.modules.layer import Layer
```

MultiHeadAttention

Causal multi-head attention with Rotary Position Embeddings (RoPE).

```python
from soma_models.v1.torch.modules.attention import MultiHeadAttention
```

| Parameter | Type | Default | Description |
|---|---|---|---|
num_heads | int | — | Number of attention heads |
num_features | int | — | Input/output dimension |
dropout_rate | float | 0.0 | Attention dropout |
use_bias | bool | True | Use bias in projections |
max_wavelength | float | 10000.0 | RoPE max wavelength |
scale_factor | float | 1.0 | RoPE scale factor |
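For intuition, RoPE with these parameters can be sketched in NumPy. This is a generic sketch, not the library's kernel; in particular the half-split pairing convention and dividing positions by `scale_factor` are assumptions:

```python
import numpy as np

def rope(x, positions, max_wavelength=10000.0, scale_factor=1.0):
    """Rotate feature pairs of x by position-dependent angles (generic RoPE sketch)."""
    seq, dim = x.shape
    half = dim // 2
    freqs = 1.0 / (max_wavelength ** (np.arange(half) / half))  # per-pair frequencies
    angles = np.outer(positions / scale_factor, freqs)          # [seq, half]
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[:, :half], x[:, half:]                           # pair i = (x1[i], x2[i])
    return np.concatenate([x1 * cos - x2 * sin,
                           x1 * sin + x2 * cos], axis=-1)

x = np.arange(40, dtype=float).reshape(5, 8)
out = rope(x, np.arange(5))
```

Each pair is rotated, so RoPE injects position information while preserving per-token embedding norms.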
PositionWiseFeedForward
Two-layer feed-forward with GELU activation.
```python
from soma_models.v1.torch.modules.pwff import PositionWiseFeedForward
```

SIGReg

Spectral Isotropy Gaussian Regularization. Penalizes anisotropy in representations via random Gaussian projections.

```python
sig_reg = SIGReg(SIGRegConfig())
loss = sig_reg(representations)  # scalar tensor
```

| Method | Signature | Description |
|---|---|---|
forward | (x: Tensor) -> Tensor | Compute regularization loss (generates noise internally) |
compute | (x: Tensor, noise: Tensor) -> Tensor | Compute with explicit noise tensor |
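One plausible reading of this description (compare each random projection's empirical characteristic function to a standard Gaussian's at a grid of quadrature points) can be sketched in NumPy. The exact statistic is an assumption; `sig_reg_sketch` is illustrative, not the library's implementation:

```python
import numpy as np

def sig_reg_sketch(x, t_max=3.0, points=17, slices=256, coefficient=0.02, seed=0):
    """Illustrative anisotropy penalty: project x onto random unit directions and
    compare each projection's empirical characteristic function to exp(-t^2/2),
    the CF of a standard Gaussian, at quadrature points in [0, t_max]."""
    n, d = x.shape
    rng = np.random.default_rng(seed)
    dirs = rng.standard_normal((slices, d))
    dirs /= np.linalg.norm(dirs, axis=1, keepdims=True)   # random unit projections
    proj = x @ dirs.T                                     # [n, slices]
    t = np.linspace(0.0, t_max, points)
    ecf = np.exp(1j * proj[..., None] * t).mean(axis=0)   # empirical CF per slice
    dev = np.abs(ecf - np.exp(-0.5 * t**2)) ** 2          # deviation from Gaussian CF
    return coefficient * dev.mean()

rng = np.random.default_rng(1)
iso_loss = sig_reg_sketch(rng.standard_normal((2048, 8)), slices=64)   # near-isotropic
aniso_loss = sig_reg_sketch(np.repeat(rng.standard_normal((2048, 1)), 8, axis=1),
                            slices=64)                                 # rank-1, anisotropic
```

Isotropic Gaussian data yields a small penalty (only sampling noise), while degenerate rank-1 data is penalized much more heavily.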
Flax

```python
from soma_models.v1.flax import Model, ModelConfig, SIGReg, SIGRegConfig, compute_loss
```

Same classes and interfaces as PyTorch, implemented with Flax nn.Module. All method signatures and behavior are identical.
compute_loss
Compute training loss: cross-entropy (ignoring PAD tokens) plus SIGReg regularization.
```python
from soma_models.v1.torch import compute_loss

loss, embedding = compute_loss(model, sig_reg, token_ids, targets)
```

| Parameter | Type | Description |
|---|---|---|
model | Model | The V1 Model instance |
sig_reg | SIGReg | SIGReg module (generates noise via global RNG) |
token_ids | Tensor | Input token IDs, shape [batch, seq] |
targets | Tensor | Next-token targets, shape [batch, seq] |
Returns: (loss, embedding) where loss is a scalar tensor (CE + SIGReg) and embedding is the mean embedding of shape [embedding_dim].
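The cross-entropy term (PAD targets excluded from the mean) can be sketched in NumPy; `masked_cross_entropy` is a hypothetical helper shown for clarity, not the library's code:

```python
import numpy as np

PAD = 256  # V1_PAD_TOKEN_ID

def masked_cross_entropy(logits, targets):
    """Mean negative log-likelihood over non-PAD target positions.
    logits: [batch, seq, vocab], targets: [batch, seq]."""
    logits = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
    nll = -np.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0]
    mask = targets != PAD
    return (nll * mask).sum() / mask.sum()

logits = np.zeros((1, 3, 264))     # uniform distribution over the 264-token vocab
targets = np.array([[1, 2, PAD]])  # the PAD position is excluded from the mean
loss = masked_cross_entropy(logits, targets)  # log(264) for uniform logits
```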
Serialization
Model weights are stored in safetensors format. The canonical format differs from PyTorch conventions — the Serde class handles the conversion automatically.
Serializable
Mixin for nn.Module subclasses. The Model class inherits this.
```python
# Save
model.save("weights.safetensors")
data = model.save_bytes()

# Load
model = Model.load("weights.safetensors", ModelConfig(dropout_rate=0.0))
model = Model.load_bytes(data, ModelConfig(dropout_rate=0.0))
```

| Method | Signature | Description |
|---|---|---|
save | (filename) -> None | Save weights to safetensors file |
save_bytes | () -> bytes | Serialize weights to bytes |
load | (filename, *args, **kwargs) -> cls | Class method — load from file |
load_bytes | (data: bytes, *args, **kwargs) -> cls | Class method — load from bytes |
Serde

Low-level serialization. Handles LayerNorm key remapping (weight/bias ↔ gamma/beta) and linear weight transposition (row-major ↔ column-major).

```python
from soma_models.v1.torch.serde import Serde

serde = Serde(model)
data = serde.serialize()
serde.deserialize(data)
```

Utilities
```python
from soma_models.utils import remap, flatten_dict, unflatten_dict
```

remap

Rename and delete keys in a dictionary, in-place.

```python
remap(d, rename_map={"old": "new"}, keys_to_delete=["unwanted"])
```

flatten_dict
Flatten a nested dictionary into dotted keys.
```python
flatten_dict({"a": {"b": 1}})  # {"a.b": 1}
```

unflatten_dict
Reverse of flatten_dict.
```python
unflatten_dict({"a.b": 1})  # {"a": {"b": 1}}
```
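For reference, the behavior of these two utilities can be sketched in a few lines (illustrative re-implementations, not the library's code):

```python
def flatten_dict_sketch(d, prefix=""):
    """Collapse nested dicts into a flat dict with dotted keys."""
    flat = {}
    for k, v in d.items():
        key = f"{prefix}{k}"
        if isinstance(v, dict):
            flat.update(flatten_dict_sketch(v, prefix=f"{key}."))
        else:
            flat[key] = v
    return flat

def unflatten_dict_sketch(flat):
    """Rebuild nested dicts from dotted keys (inverse of the above)."""
    nested = {}
    for dotted, v in flat.items():
        *parents, leaf = dotted.split(".")
        node = nested
        for p in parents:
            node = node.setdefault(p, {})
        node[leaf] = v
    return nested

nested = {"a": {"b": 1, "c": {"d": 2}}, "e": 3}
flat = flatten_dict_sketch(nested)
```

Flattening and unflattening round-trip: `unflatten_dict_sketch(flatten_dict_sketch(d)) == d` for nested dicts with dot-free keys.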