
Models Reference

pip install soma-models

Note that PyTorch (torch) and/or Flax (flax, jax) are optional dependencies — install whichever framework you need.

pip install soma-models torch # PyTorch
pip install soma-models flax jax # Flax/JAX

Five dataclass configs control model architecture. All have sensible defaults matching the V1 architecture.

ModelConfig: top-level model configuration.

from soma_models.v1.configs import ModelConfig
config = ModelConfig(dropout_rate=0.1)

| Field | Type | Default | Description |
|---|---|---|---|
| dropout_rate | float | | Dropout probability (required) |
| embedding_dim | int | 2048 | Token embedding dimension |
| pwff_hidden_dim | int | 8192 | Feed-forward hidden dimension |
| num_layers | int | 24 | Number of encoder layers |
| num_heads | int | 8 | Number of attention heads |
| vocab_size | int | 264 | Vocabulary size (256 bytes + 8 special tokens) |
| max_wavelength | float | 10000.0 | RoPE max wavelength |
| scale_factor | float | 1.0 | RoPE scale factor |

| Field | Type | Default | Description |
|---|---|---|---|
| dropout_rate | float | | Dropout probability (required) |
| embedding_dim | int | 2048 | Embedding dimension |
| pwff_hidden_dim | int | 8192 | Feed-forward hidden dimension |
| num_layers | int | 24 | Number of layers |
| num_heads | int | 8 | Attention heads |
| max_wavelength | float | 10000.0 | RoPE max wavelength |
| scale_factor | float | 1.0 | RoPE scale factor |

| Field | Type | Default | Description |
|---|---|---|---|
| dropout_rate | float | | Dropout probability (required) |
| embedding_dim | int | 2048 | Embedding dimension |
| pwff_hidden_dim | int | 8192 | Feed-forward hidden dimension |
| num_heads | int | 8 | Attention heads |
| max_wavelength | float | 10000.0 | RoPE max wavelength |
| scale_factor | float | 1.0 | RoPE scale factor |

| Field | Type | Default | Description |
|---|---|---|---|
| dropout_rate | float | | Dropout probability (required) |
| embedding_dim | int | 2048 | Input/output dimension |
| pwff_hidden_dim | int | 8192 | Hidden dimension |

SIGRegConfig: Spectral Isotropy Gaussian Regularization config.

| Field | Type | Default | Description |
|---|---|---|---|
| t_max | float | 3.0 | Maximum integration bound |
| points | int | 17 | Number of quadrature points |
| slices | int | 256 | Number of random projections |
| coefficient | float | 0.02 | Regularization weight |

| Constant | Value | Description |
|---|---|---|
| V1_EMBEDDING_DIM | 2048 | Token embedding dimension |
| V1_NUM_HEADS | 8 | Attention heads |
| V1_NUM_LAYERS | 24 | Encoder layers |
| V1_MAX_SEQ_LEN | 1024 | Maximum sequence length |
| V1_VOCAB_SIZE | 264 | 256 byte values + 8 special tokens |
| V1_PWFF_HIDDEN_DIM | 8192 | Feed-forward hidden dimension |
| V1_MAX_WAVELENGTH | 10000.0 | RoPE max wavelength |
| V1_SCALE_FACTOR | 1.0 | RoPE scale factor |
| V1_PAD_TOKEN_ID | 256 | Padding token ID |
| V1_EOS_TOKEN_ID | 257 | End-of-sequence token ID |
| V1_SIG_REG_T_MAX | 3.0 | SIGReg integration bound |
| V1_SIG_REG_SLICES | 256 | SIGReg random projections |
| V1_SIG_REG_POINTS | 17 | SIGReg quadrature points |
| V1_SIG_REG_COEFFICIENT | 0.02 | SIGReg weight |
| V1_BATCH_SIZE | 16 | Default batch size |
| ARCHITECTURE_VERSION | 1 | Must match on-chain model_architecture_version |

Framework-agnostic byte-level tokenizer matching the on-chain V1 data contract.

Convert raw bytes into batches of token IDs, targets, and position IDs.

from soma_models.v1.tokenizer import tokenize
batches = tokenize(data, max_seq_len=1024, batch_size=16)

| Parameter | Type | Required | Default | Description |
|---|---|---|---|---|
| data | bytes \| bytearray | Yes | | Raw byte data to tokenize |
| max_seq_len | int | No | 1024 | Maximum sequence length per chunk |
| batch_size | int | No | 16 | Sequences per batch |

Returns: list[ByteSequenceBatch] — the final batch may contain fewer than batch_size sequences.

| Field | Type | Description |
|---|---|---|
| token_ids | list[list[int]] | [batch, seq_len] — input token IDs |
| targets | list[list[int]] | [batch, seq_len] — next-token targets (shifted left, PAD appended) |
| pos_ids | list[list[int]] | [batch, seq_len] — global byte offsets |
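The per-chunk contract can be illustrated with a small pure-Python sketch. This is not the library code: it ignores EOS handling and batching, and assumes the PAD token ID 256 from the constants table.

```python
PAD = 256  # V1_PAD_TOKEN_ID

def chunk_bytes(data: bytes, max_seq_len: int):
    """Illustrative sketch: targets are the input shifted left by one
    with PAD appended, and pos_ids are global byte offsets rather than
    chunk-local indices."""
    chunks = []
    for start in range(0, len(data), max_seq_len):
        token_ids = list(data[start:start + max_seq_len])
        targets = token_ids[1:] + [PAD]
        pos_ids = list(range(start, start + len(token_ids)))
        chunks.append((token_ids, targets, pos_ids))
    return chunks
```

Note how the second chunk's pos_ids continue from the first, which is what "global byte offsets" means in the table above.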
from soma_models.v1.torch import Model, ModelConfig, SIGReg, SIGRegConfig, compute_loss

Inherits Serializable and nn.Module. Full transformer: embedding → encoder → final norm → predictor.

model = Model(ModelConfig(dropout_rate=0.1))

| Method | Signature | Description |
|---|---|---|
| encode | (input: Tensor, positions: Tensor) -> Tensor | Embed + encode, returns [batch, seq, dim] |
| predict | (embeddings: Tensor) -> Tensor | Linear projection to vocab logits |
| forward | (input: Tensor, positions: Tensor) -> Tensor | predict(encode(input, positions)) |

Stack of Layer modules.

from soma_models.v1.torch.modules.encoder import Encoder

| Method | Signature | Description |
|---|---|---|
| forward | (input: Tensor, positions: Tensor) -> Tensor | Pass through all layers sequentially |

Pre-norm transformer layer: LayerNorm → Attention → residual → LayerNorm → FeedForward → residual.

from soma_models.v1.torch.modules.layer import Layer

Causal multi-head attention with Rotary Position Embeddings (RoPE).

from soma_models.v1.torch.modules.attention import MultiHeadAttention

| Parameter | Type | Default | Description |
|---|---|---|---|
| num_heads | int | | Number of attention heads |
| num_features | int | | Input/output dimension |
| dropout_rate | float | 0.0 | Attention dropout |
| use_bias | bool | True | Use bias in projections |
| max_wavelength | float | 10000.0 | RoPE max wavelength |
| scale_factor | float | 1.0 | RoPE scale factor |
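The max_wavelength and scale_factor parameters control the rotary angle schedule. A sketch of the standard RoPE formulation follows; the library's exact variant may differ, and the assumption here is that scale_factor rescales positions:

```python
def rope_angles(position, head_dim, max_wavelength=10000.0, scale_factor=1.0):
    """Standard RoPE: feature pair (2i, 2i+1) is rotated by
    pos / max_wavelength**(2i / head_dim). scale_factor dividing the
    position is an assumption about how this library applies it."""
    pos = position / scale_factor
    return [pos / max_wavelength ** (2 * i / head_dim)
            for i in range(head_dim // 2)]
```

Lower feature pairs rotate fastest; raising max_wavelength slows the rotation of the higher pairs, extending the effective positional range.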

Two-layer feed-forward with GELU activation.

from soma_models.v1.torch.modules.pwff import PositionWiseFeedForward

Spectral Isotropy Gaussian Regularization. Penalizes anisotropy in representations via random Gaussian projections.

sig_reg = SIGReg(SIGRegConfig())
loss = sig_reg(representations) # scalar tensor

| Method | Signature | Description |
|---|---|---|
| forward | (x: Tensor) -> Tensor | Compute regularization loss (generates noise internally) |
| compute | (x: Tensor, noise: Tensor) -> Tensor | Compute with explicit noise tensor |

Compute training loss: cross-entropy (ignoring PAD tokens) plus SIGReg regularization.

from soma_models.v1.torch import compute_loss
loss, embedding = compute_loss(model, sig_reg, token_ids, targets)

| Parameter | Type | Description |
|---|---|---|
| model | Model | The V1 Model instance |
| sig_reg | SIGReg | SIGReg module (generates noise via global RNG) |
| token_ids | Tensor | Input token IDs, shape [batch, seq] |
| targets | Tensor | Next-token targets, shape [batch, seq] |

Returns: (loss, embedding) where loss is a scalar tensor (CE + SIGReg) and embedding is the mean embedding of shape [embedding_dim].
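The combination can be sketched in plain Python. This is illustrative only: real cross-entropy operates on logits, and whether the coefficient is applied inside SIGReg or inside compute_loss is an implementation detail of the library.

```python
def combine_losses(ce_per_token, is_pad, sig_reg_term, coefficient=0.02):
    """Mean cross-entropy over non-PAD positions, plus the weighted
    SIGReg penalty (coefficient default taken from SIGRegConfig)."""
    kept = [ce for ce, pad in zip(ce_per_token, is_pad) if not pad]
    return sum(kept) / len(kept) + coefficient * sig_reg_term
```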

Model weights are stored in safetensors format. The canonical format differs from PyTorch conventions — the Serde class handles the conversion automatically.

Mixin for nn.Module subclasses. The Model class inherits this.

# Save
model.save("weights.safetensors")
data = model.save_bytes()
# Load
model = Model.load("weights.safetensors", ModelConfig(dropout_rate=0.0))
model = Model.load_bytes(data, ModelConfig(dropout_rate=0.0))

| Method | Signature | Description |
|---|---|---|
| save | (filename) -> None | Save weights to safetensors file |
| save_bytes | () -> bytes | Serialize weights to bytes |
| load | (filename, *args, **kwargs) -> cls | Class method — load from file |
| load_bytes | (data: bytes, *args, **kwargs) -> cls | Class method — load from bytes |

Low-level serialization. Handles LayerNorm key remapping (weight/bias ↔ gamma/beta) and linear weight transposition (row-major ↔ column-major).

from soma_models.v1.torch.serde import Serde
serde = Serde(model)
data = serde.serialize()
serde.deserialize(data)
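For context on the transposition step: PyTorch's nn.Linear stores weights as [out_features, in_features], so converting to a column-major canonical layout amounts to a transpose. A minimal sketch (the actual key mapping and layout handling live in Serde):

```python
def transpose(matrix):
    """Swap rows and columns of a 2-D list, e.g. an [out, in] linear
    weight to [in, out]."""
    return [list(row) for row in zip(*matrix)]
```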
from soma_models.utils import remap, flatten_dict, unflatten_dict

Rename and delete keys in a dictionary, in-place.

remap(d, rename_map={"old": "new"}, keys_to_delete=["unwanted"])
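Its behavior can be approximated by this sketch (illustrative; the real function may differ in edge cases such as absent keys):

```python
def remap_sketch(d, rename_map=None, keys_to_delete=None):
    """In-place: move values from old keys to new keys, then drop
    unwanted keys (silently skipping absent ones in this sketch)."""
    for old, new in (rename_map or {}).items():
        if old in d:
            d[new] = d.pop(old)
    for key in keys_to_delete or ():
        d.pop(key, None)
    return d
```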

Flatten a nested dictionary into dotted keys.

flatten_dict({"a": {"b": 1}}) # {"a.b": 1}

Reverse of flatten_dict.

unflatten_dict({"a.b": 1}) # {"a": {"b": 1}}
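Both helpers are straightforward to sketch as illustrative re-implementations (not the library source):

```python
def flatten(d, prefix=""):
    """Nested dict -> flat dict with dot-joined keys."""
    out = {}
    for k, v in d.items():
        key = prefix + k
        if isinstance(v, dict):
            out.update(flatten(v, key + "."))
        else:
            out[key] = v
    return out

def unflatten(d):
    """Flat dict with dotted keys -> nested dict."""
    out = {}
    for key, v in d.items():
        parts = key.split(".")
        node = out
        for p in parts[:-1]:
            node = node.setdefault(p, {})
        node[parts[-1]] = v
    return out
```

The two are inverses on dicts whose keys contain no dots, which is why they pair naturally with the key remapping done during serialization.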