
Models Reference

pip install soma-models

PyTorch and Flax are optional extras. Install whichever framework you need.

pip install soma-models[torch] # PyTorch
pip install soma-models[flax] # Flax/JAX

Five dataclass configs control model architecture. All have sensible defaults matching the V1 architecture.

Top-level model configuration.

from soma_models.v1.configs import ModelConfig
config = ModelConfig(dropout_rate=0.1)
| Field | Type | Default | Description |
| --- | --- | --- | --- |
| dropout_rate | float | - | Dropout probability (required) |
| embedding_dim | int | 2048 | Token embedding dimension |
| pwff_hidden_dim | int | 8192 | Feed-forward hidden dimension |
| num_layers | int | 24 | Number of encoder layers |
| num_heads | int | 8 | Number of attention heads |
| vocab_size | int | 264 | Vocabulary size (256 bytes + 8 special tokens) |
| max_wavelength | float | 10000.0 | RoPE max wavelength |
| scale_factor | float | 1.0 | RoPE scale factor |
| Field | Type | Default | Description |
| --- | --- | --- | --- |
| dropout_rate | float | - | Dropout probability (required) |
| embedding_dim | int | 2048 | Embedding dimension |
| pwff_hidden_dim | int | 8192 | Feed-forward hidden dimension |
| num_layers | int | 24 | Number of layers |
| num_heads | int | 8 | Attention heads |
| max_wavelength | float | 10000.0 | RoPE max wavelength |
| scale_factor | float | 1.0 | RoPE scale factor |
| Field | Type | Default | Description |
| --- | --- | --- | --- |
| dropout_rate | float | - | Dropout probability (required) |
| embedding_dim | int | 2048 | Embedding dimension |
| pwff_hidden_dim | int | 8192 | Feed-forward hidden dimension |
| num_heads | int | 8 | Attention heads |
| max_wavelength | float | 10000.0 | RoPE max wavelength |
| scale_factor | float | 1.0 | RoPE scale factor |
| Field | Type | Default | Description |
| --- | --- | --- | --- |
| dropout_rate | float | - | Dropout probability (required) |
| embedding_dim | int | 2048 | Input/output dimension |
| pwff_hidden_dim | int | 8192 | Hidden dimension |

Spectral Isotropy Gaussian Regularization config.

| Field | Type | Default | Description |
| --- | --- | --- | --- |
| t_max | float | 3.0 | Maximum integration bound |
| points | int | 17 | Number of quadrature points |
| slices | int | 256 | Number of random projections |
| coefficient | float | 0.02 | Regularization weight |
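How these parameters interact can be pictured with a framework-free sketch: draw `slices` random unit directions, project the embeddings onto each, and compare the empirical characteristic function of each 1-D projection against the standard Gaussian's on a grid of `points` values in [-t_max, t_max]. This is an illustration of the idea only, not the module's actual implementation.

```python
import math
import random

def sigreg_sketch(embeddings, t_max=3.0, points=17, slices=4, seed=0):
    """Conceptual sketch of an isotropy penalty via random 1-D projections.

    embeddings: list of equal-length float lists. Returns a non-negative
    float that is small when projections look standard-Gaussian.
    """
    rng = random.Random(seed)
    dim = len(embeddings[0])
    # Quadrature grid in [-t_max, t_max]
    ts = [-t_max + 2 * t_max * i / (points - 1) for i in range(points)]
    loss = 0.0
    for _ in range(slices):
        # Random unit direction
        v = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        norm = math.sqrt(sum(c * c for c in v))
        v = [c / norm for c in v]
        proj = [sum(e * c for e, c in zip(x, v)) for x in embeddings]
        for t in ts:
            # Empirical characteristic function vs Gaussian exp(-t^2 / 2)
            re = sum(math.cos(t * p) for p in proj) / len(proj)
            im = sum(math.sin(t * p) for p in proj) / len(proj)
            g = math.exp(-t * t / 2)
            loss += (re - g) ** 2 + im ** 2
    return loss / (slices * points)
```

A degenerate (all-zero) batch of embeddings is maximally non-Gaussian and yields a clearly positive penalty.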
| Constant | Value | Description |
| --- | --- | --- |
| V1_EMBEDDING_DIM | 2048 | Token embedding dimension |
| V1_NUM_HEADS | 8 | Attention heads |
| V1_NUM_LAYERS | 24 | Encoder layers |
| V1_MAX_SEQ_LEN | 1024 | Maximum sequence length |
| V1_VOCAB_SIZE | 264 | 256 byte values + 8 special tokens |
| V1_PWFF_HIDDEN_DIM | 8192 | Feed-forward hidden dimension |
| V1_MAX_WAVELENGTH | 10000.0 | RoPE max wavelength |
| V1_SCALE_FACTOR | 1.0 | RoPE scale factor |
| V1_PAD_TOKEN_ID | 256 | Padding token ID |
| V1_EOS_TOKEN_ID | 257 | End-of-sequence token ID |
| V1_SIG_REG_T_MAX | 3.0 | SIGReg integration bound |
| V1_SIG_REG_SLICES | 256 | SIGReg random projections |
| V1_SIG_REG_POINTS | 17 | SIGReg quadrature points |
| V1_SIG_REG_COEFFICIENT | 0.02 | SIGReg weight |
| V1_BATCH_SIZE | 16 | Default batch size |
| ARCHITECTURE_VERSION | 1 | Must match the network model_architecture_version |

Framework-agnostic byte-level tokenizer matching the network’s V1 data contract.

Convert raw bytes into a list of tokenized sequences.

from soma_models.v1.tokenizer import tokenize
sequences = tokenize(data, max_seq_len=1024)
| Parameter | Type | Required | Default | Description |
| --- | --- | --- | --- | --- |
| data | bytes \| bytearray | Yes | - | Raw byte data to tokenize |
| max_seq_len | int | No | 1024 | Maximum sequence length per chunk |

Returns: list[TokenizedSequence]. One sequence per chunk of the input data.

| Field | Type | Description |
| --- | --- | --- |
| token_ids | list[int] | [seq_len], input token IDs |
| targets | list[int] | [seq_len], next-token targets (shifted left, PAD appended) |
| pos_ids | list[int] | [seq_len], global byte offsets |
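The chunking and target layout can be sketched without the library. This is an illustration of the documented contract only (the real tokenizer may handle EOS and other special tokens differently); the PAD ID of 256 is V1_PAD_TOKEN_ID from the constants table.

```python
V1_PAD_TOKEN_ID = 256  # from the constants table

def tokenize_sketch(data: bytes, max_seq_len: int = 1024):
    """Illustrative chunking: byte values become token IDs, targets are
    the IDs shifted left with PAD appended, pos_ids are global offsets."""
    sequences = []
    for start in range(0, len(data), max_seq_len):
        chunk = list(data[start:start + max_seq_len])
        sequences.append({
            "token_ids": chunk,
            "targets": chunk[1:] + [V1_PAD_TOKEN_ID],
            "pos_ids": list(range(start, start + len(chunk))),
        })
    return sequences
```

For example, `tokenize_sketch(b"abcd", max_seq_len=2)` yields two sequences, with the second sequence's pos_ids continuing from the global offset rather than restarting at zero.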
from soma_models.v1.torch import Model, ModelConfig, SIGReg, SIGRegConfig, compute_loss

Inherits Serializable and nn.Module. Full transformer: embedding → encoder → final norm → predictor.

model = Model(ModelConfig(dropout_rate=0.1))
| Method | Signature | Description |
| --- | --- | --- |
| encode | (input: Tensor, positions: Tensor) -> Tensor | Embed + encode, returns [batch, seq, dim] |
| predict | (embeddings: Tensor) -> Tensor | Linear projection to vocab logits |
| forward | (input: Tensor, positions: Tensor) -> Tensor | predict(encode(input, positions)) |

Stack of Layer modules.

from soma_models.v1.torch.modules.encoder import Encoder
| Method | Signature | Description |
| --- | --- | --- |
| forward | (input: Tensor, positions: Tensor) -> Tensor | Pass through all layers sequentially |

Pre-norm transformer layer: LayerNorm → Attention → residual → LayerNorm → FeedForward → residual.

from soma_models.v1.torch.modules.layer import Layer

Causal multi-head attention with Rotary Position Embeddings (RoPE).

from soma_models.v1.torch.modules.attention import MultiHeadAttention
| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| num_heads | int | - | Number of attention heads |
| num_features | int | - | Input/output dimension |
| dropout_rate | float | 0.0 | Attention dropout |
| use_bias | bool | True | Use bias in projections |
| max_wavelength | float | 10000.0 | RoPE max wavelength |
| scale_factor | float | 1.0 | RoPE scale factor |
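The role of max_wavelength and scale_factor can be illustrated with the standard RoPE rotation of a single (even, odd) feature pair. This is a conceptual sketch of the usual formulation, not this module's code:

```python
import math

def rope_rotate(x_pair, position, dim_index, num_features,
                max_wavelength=10000.0, scale_factor=1.0):
    """Rotate one (even, odd) feature pair by the RoPE angle for `position`.

    max_wavelength sets the slowest rotation frequency; scale_factor
    stretches positions before the angle is computed.
    """
    # Per-pair frequency, as in the standard RoPE formulation
    freq = 1.0 / (max_wavelength ** (2 * dim_index / num_features))
    theta = (position / scale_factor) * freq
    x_even, x_odd = x_pair
    return (x_even * math.cos(theta) - x_odd * math.sin(theta),
            x_even * math.sin(theta) + x_odd * math.cos(theta))
```

Position 0 is the identity rotation, and every rotation preserves the pair's norm, which is why RoPE composes cleanly with dot-product attention.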

Two-layer feed-forward with GELU activation.

from soma_models.v1.torch.modules.pwff import PositionWiseFeedForward

Spectral Isotropy Gaussian Regularization. Penalizes anisotropy in representations via random Gaussian projections.

sig_reg = SIGReg(SIGRegConfig())
loss = sig_reg(representations) # scalar tensor
| Method | Signature | Description |
| --- | --- | --- |
| forward | (x: Tensor) -> Tensor | Compute regularization loss (generates noise internally) |
| compute | (x: Tensor, noise: Tensor) -> Tensor | Compute with explicit noise tensor |

Compute training loss: cross-entropy (ignoring PAD tokens) plus SIGReg regularization.

from soma_models.v1.torch import compute_loss
loss, embedding = compute_loss(model, sig_reg, token_ids, targets)
| Parameter | Type | Description |
| --- | --- | --- |
| model | Model | The V1 Model instance |
| sig_reg | SIGReg | SIGReg module (generates noise via global RNG) |
| token_ids | Tensor | Input token IDs, shape [batch, seq] |
| targets | Tensor | Next-token targets, shape [batch, seq] |

Returns: (loss, embedding) where loss is a scalar tensor (CE + SIGReg) and embedding is the mean embedding of shape [embedding_dim].
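The cross-entropy side of the loss can be sketched without a framework (the SIGReg term is omitted; the PAD ID of 256 is V1_PAD_TOKEN_ID from the constants table). This mirrors the "ignoring PAD tokens" behavior described above, not the actual implementation:

```python
import math

V1_PAD_TOKEN_ID = 256  # from the constants table

def masked_cross_entropy(logits, targets, pad_id=V1_PAD_TOKEN_ID):
    """Mean negative log-likelihood over positions whose target is not PAD.

    logits: list of per-position raw score lists (one list per position)
    targets: list of target token IDs, same length as logits
    """
    total, count = 0.0, 0
    for scores, target in zip(logits, targets):
        if target == pad_id:
            continue  # PAD targets contribute nothing to the loss
        # Numerically stable log-sum-exp for the log-softmax denominator
        m = max(scores)
        log_z = m + math.log(sum(math.exp(s - m) for s in scores))
        total += log_z - scores[target]
        count += 1
    return total / count if count else 0.0
```

With uniform scores over a vocabulary of size V, the per-position loss is log(V), a useful sanity check when wiring up training.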

Model weights are stored in safetensors format. The canonical format differs from PyTorch conventions. The Serde class handles the conversion automatically.

Mixin for nn.Module subclasses. The Model class inherits this.

# Save
model.save("weights.safetensors")
data = model.save_bytes()
# Load
model = Model.load("weights.safetensors", ModelConfig(dropout_rate=0.0))
model = Model.load_bytes(data, ModelConfig(dropout_rate=0.0))
| Method | Signature | Description |
| --- | --- | --- |
| save | (filename) -> None | Save weights to safetensors file |
| save_bytes | () -> bytes | Serialize weights to bytes |
| load | (filename, *args, **kwargs) -> cls | Class method: load from file |
| load_bytes | (data: bytes, *args, **kwargs) -> cls | Class method: load from bytes |

Low-level serialization. Handles LayerNorm key remapping (weight/bias ↔ gamma/beta) and linear weight transposition (row-major ↔ column-major).

from soma_models.v1.torch.serde import Serde
serde = Serde(model)
data = serde.serialize()
serde.deserialize(data)
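The row-major ↔ column-major step amounts to a plain matrix transpose, which can be pictured with nested lists (an illustration of the concept, not Serde's tensor code):

```python
def transpose(matrix):
    """Swap rows and columns: the [out, in] <-> [in, out] shape flip
    applied to linear weights when converting between layouts."""
    return [list(row) for row in zip(*matrix)]
```

Applying it twice returns the original layout, which is what makes the serialize/deserialize round trip lossless.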
from soma_models.utils import remap, flatten_dict, unflatten_dict

Rename and delete keys in a dictionary, in-place.

remap(d, rename_map={"old": "new"}, keys_to_delete=["unwanted"])
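A minimal sketch of the documented behavior (a hypothetical reimplementation, not the library's code): rename keys per rename_map, drop keys_to_delete, mutating the dict in place.

```python
def remap_sketch(d, rename_map=None, keys_to_delete=None):
    """In-place key rename and delete, mirroring the documented remap."""
    for old, new in (rename_map or {}).items():
        if old in d:
            d[new] = d.pop(old)
    for key in keys_to_delete or ():
        d.pop(key, None)  # deleting a missing key is a no-op
    return d
```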

Flatten a nested dictionary into dotted keys.

flatten_dict({"a": {"b": 1}}) # {"a.b": 1}

Reverse of flatten_dict.

unflatten_dict({"a.b": 1}) # {"a": {"b": 1}}
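Both utilities can be sketched in a few lines each (hypothetical reimplementations matching the examples above, not the library's code):

```python
def flatten_dict_sketch(d, prefix=""):
    """Flatten nested dicts into dotted keys: {"a": {"b": 1}} -> {"a.b": 1}."""
    out = {}
    for key, value in d.items():
        dotted = f"{prefix}{key}"
        if isinstance(value, dict):
            out.update(flatten_dict_sketch(value, prefix=f"{dotted}."))
        else:
            out[dotted] = value
    return out

def unflatten_dict_sketch(d):
    """Rebuild nesting from dotted keys: {"a.b": 1} -> {"a": {"b": 1}}."""
    out = {}
    for dotted, value in d.items():
        *parents, leaf = dotted.split(".")
        node = out
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return out
```

The two are inverses as long as no dictionary key itself contains a dot.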