# Models Reference
## Installation

```shell
pip install soma-models
```

PyTorch and Flax are optional extras. Install whichever framework you need:

```shell
pip install soma-models[torch]  # PyTorch
pip install soma-models[flax]   # Flax/JAX
```

## Configuration

Five dataclass configs control the model architecture. All have sensible defaults matching the V1 architecture.
### ModelConfig

Top-level model configuration.

```python
from soma_models.v1.configs import ModelConfig

config = ModelConfig(dropout_rate=0.1)
```

| Field | Type | Default | Description |
|---|---|---|---|
| dropout_rate | float | - | Dropout probability (required) |
| embedding_dim | int | 2048 | Token embedding dimension |
| pwff_hidden_dim | int | 8192 | Feed-forward hidden dimension |
| num_layers | int | 24 | Number of encoder layers |
| num_heads | int | 8 | Number of attention heads |
| vocab_size | int | 264 | Vocabulary size (256 bytes + 8 special tokens) |
| max_wavelength | float | 10000.0 | RoPE max wavelength |
| scale_factor | float | 1.0 | RoPE scale factor |
### EncoderConfig

| Field | Type | Default | Description |
|---|---|---|---|
| dropout_rate | float | - | Dropout probability (required) |
| embedding_dim | int | 2048 | Embedding dimension |
| pwff_hidden_dim | int | 8192 | Feed-forward hidden dimension |
| num_layers | int | 24 | Number of layers |
| num_heads | int | 8 | Attention heads |
| max_wavelength | float | 10000.0 | RoPE max wavelength |
| scale_factor | float | 1.0 | RoPE scale factor |
### LayerConfig

| Field | Type | Default | Description |
|---|---|---|---|
| dropout_rate | float | - | Dropout probability (required) |
| embedding_dim | int | 2048 | Embedding dimension |
| pwff_hidden_dim | int | 8192 | Feed-forward hidden dimension |
| num_heads | int | 8 | Attention heads |
| max_wavelength | float | 10000.0 | RoPE max wavelength |
| scale_factor | float | 1.0 | RoPE scale factor |
### PositionWiseFeedForwardConfig

| Field | Type | Default | Description |
|---|---|---|---|
| dropout_rate | float | - | Dropout probability (required) |
| embedding_dim | int | 2048 | Input/output dimension |
| pwff_hidden_dim | int | 8192 | Hidden dimension |
### SIGRegConfig

Spectral Isotropy Gaussian Regularization config.

| Field | Type | Default | Description |
|---|---|---|---|
| t_max | float | 3.0 | Maximum integration bound |
| points | int | 17 | Number of quadrature points |
| slices | int | 256 | Number of random projections |
| coefficient | float | 0.02 | Regularization weight |
## Constants

| Constant | Value | Description |
|---|---|---|
| V1_EMBEDDING_DIM | 2048 | Token embedding dimension |
| V1_NUM_HEADS | 8 | Attention heads |
| V1_NUM_LAYERS | 24 | Encoder layers |
| V1_MAX_SEQ_LEN | 1024 | Maximum sequence length |
| V1_VOCAB_SIZE | 264 | 256 byte values + 8 special tokens |
| V1_PWFF_HIDDEN_DIM | 8192 | Feed-forward hidden dimension |
| V1_MAX_WAVELENGTH | 10000.0 | RoPE max wavelength |
| V1_SCALE_FACTOR | 1.0 | RoPE scale factor |
| V1_PAD_TOKEN_ID | 256 | Padding token ID |
| V1_EOS_TOKEN_ID | 257 | End-of-sequence token ID |
| V1_SIG_REG_T_MAX | 3.0 | SIGReg integration bound |
| V1_SIG_REG_SLICES | 256 | SIGReg random projections |
| V1_SIG_REG_POINTS | 17 | SIGReg quadrature points |
| V1_SIG_REG_COEFFICIENT | 0.02 | SIGReg weight |
| V1_BATCH_SIZE | 16 | Default batch size |
| ARCHITECTURE_VERSION | 1 | Must match the network's model_architecture_version |
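The arithmetic relationships among these constants can be sanity-checked directly. `NUM_BYTE_VALUES` and `NUM_SPECIAL_TOKENS` below are illustrative names, not exports of the package:

```python
# Relationships among the V1 constants listed above.
NUM_BYTE_VALUES = 256      # token IDs 0-255 are raw byte values
NUM_SPECIAL_TOKENS = 8     # PAD, EOS, and six reserved IDs (assumption)
V1_VOCAB_SIZE = NUM_BYTE_VALUES + NUM_SPECIAL_TOKENS
V1_PAD_TOKEN_ID = 256      # first ID after the byte range
V1_EOS_TOKEN_ID = 257

assert V1_VOCAB_SIZE == 264
assert NUM_BYTE_VALUES <= V1_PAD_TOKEN_ID < V1_VOCAB_SIZE
assert V1_EOS_TOKEN_ID == V1_PAD_TOKEN_ID + 1
```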
## Tokenizer

Framework-agnostic byte-level tokenizer matching the network's V1 data contract.

### tokenize

Convert raw bytes into a list of tokenized sequences.

```python
from soma_models.v1.tokenizer import tokenize

sequences = tokenize(data, max_seq_len=1024)
```

| Parameter | Type | Required | Default | Description |
|---|---|---|---|---|
| data | bytes \| bytearray | Yes | - | Raw byte data to tokenize |
| max_seq_len | int | No | 1024 | Maximum sequence length per chunk |

Returns: `list[TokenizedSequence]`, one sequence per chunk of the input data.
### TokenizedSequence

| Field | Type | Description |
|---|---|---|
| token_ids | list[int] | [seq_len], input token IDs |
| targets | list[int] | [seq_len], next-token targets (shifted left, PAD appended) |
| pos_ids | list[int] | [seq_len], global byte offsets |
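The data contract above can be illustrated with a minimal sketch. This is not the library's implementation: `tokenize_sketch` is a hypothetical name, and details such as EOS insertion and chunk-boundary targets are simplified assumptions.

```python
def tokenize_sketch(data: bytes, max_seq_len: int = 1024,
                    pad_id: int = 256) -> list[dict]:
    """Sketch of the V1 byte-level tokenization contract:
    one sequence per max_seq_len-sized chunk of the input."""
    sequences = []
    for start in range(0, len(data), max_seq_len):
        chunk = data[start:start + max_seq_len]
        token_ids = list(chunk)                 # bytes map directly to IDs 0-255
        targets = token_ids[1:] + [pad_id]      # shifted left, PAD appended
        pos_ids = list(range(start, start + len(chunk)))  # global byte offsets
        sequences.append({"token_ids": token_ids,
                          "targets": targets,
                          "pos_ids": pos_ids})
    return sequences
```

For example, `tokenize_sketch(b"abcd", max_seq_len=3)` yields two sequences, with the second chunk's `pos_ids` continuing from the global offset 3 rather than restarting at 0.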
## Model Components

### Model

```python
from soma_models.v1.torch import Model, ModelConfig, SIGReg, SIGRegConfig, compute_loss
```

Inherits Serializable and nn.Module. Full transformer: embedding → encoder → final norm → predictor.

```python
model = Model(ModelConfig(dropout_rate=0.1))
```

| Method | Signature | Description |
|---|---|---|
| encode | (input: Tensor, positions: Tensor) -> Tensor | Embed + encode, returns [batch, seq, dim] |
| predict | (embeddings: Tensor) -> Tensor | Linear projection to vocab logits |
| forward | (input: Tensor, positions: Tensor) -> Tensor | predict(encode(input, positions)) |
### Encoder

Stack of Layer modules.

```python
from soma_models.v1.torch.modules.encoder import Encoder
```

| Method | Signature | Description |
|---|---|---|
| forward | (input: Tensor, positions: Tensor) -> Tensor | Pass through all layers sequentially |
### Layer

Pre-norm transformer layer: LayerNorm → Attention → residual → LayerNorm → FeedForward → residual.

```python
from soma_models.v1.torch.modules.layer import Layer
```

### MultiHeadAttention

Causal multi-head attention with Rotary Position Embeddings (RoPE).

```python
from soma_models.v1.torch.modules.attention import MultiHeadAttention
```

| Parameter | Type | Default | Description |
|---|---|---|---|
| num_heads | int | - | Number of attention heads |
| num_features | int | - | Input/output dimension |
| dropout_rate | float | 0.0 | Attention dropout |
| use_bias | bool | True | Use bias in projections |
| max_wavelength | float | 10000.0 | RoPE max wavelength |
| scale_factor | float | 1.0 | RoPE scale factor |
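To make the `max_wavelength` and `scale_factor` parameters concrete, here is a pure-Python sketch of how RoPE rotates a single feature vector at a given position. This illustrates the standard RoPE formulation, not this library's internals:

```python
import math

def rope_rotate(vec: list[float], position: int,
                max_wavelength: float = 10000.0,
                scale_factor: float = 1.0) -> list[float]:
    """Rotate consecutive feature pairs by a position-dependent angle.
    Pair i uses frequency 1 / max_wavelength**(i / dim); scale_factor
    stretches positions before the rotation. Requires an even dim."""
    dim = len(vec)
    out = [0.0] * dim
    for i in range(0, dim, 2):
        freq = 1.0 / (max_wavelength ** (i / dim))
        angle = (position / scale_factor) * freq
        c, s = math.cos(angle), math.sin(angle)
        out[i] = vec[i] * c - vec[i + 1] * s
        out[i + 1] = vec[i] * s + vec[i + 1] * c
    return out
```

Because each pair undergoes a pure rotation, the vector's norm is preserved, and position 0 leaves the vector unchanged.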
### PositionWiseFeedForward

Two-layer feed-forward with GELU activation.

```python
from soma_models.v1.torch.modules.pwff import PositionWiseFeedForward
```

### SIGReg

Spectral Isotropy Gaussian Regularization. Penalizes anisotropy in representations via random Gaussian projections.

```python
sig_reg = SIGReg(SIGRegConfig())
loss = sig_reg(representations)  # scalar tensor
```

| Method | Signature | Description |
|---|---|---|
| forward | (x: Tensor) -> Tensor | Compute regularization loss (generates noise internally) |
| compute | (x: Tensor, noise: Tensor) -> Tensor | Compute with explicit noise tensor |
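A rough sketch of how such a regularizer could combine the `t_max`, `points`, `slices`, and `coefficient` settings: project representations onto random unit directions and penalize deviation of each projection's empirical characteristic function from a standard Gaussian's. This is an assumed reading of the config fields, not the package's actual algorithm, and `sig_reg_sketch` is a hypothetical name:

```python
import math
import random

def sig_reg_sketch(x: list[list[float]], t_max: float = 3.0, points: int = 17,
                   slices: int = 256, coefficient: float = 0.02,
                   seed: int = 0) -> float:
    """Penalize non-Gaussianity of 1-D random projections of x ([n][d]),
    measured at `points` values of t spread over [-t_max, t_max]."""
    rng = random.Random(seed)
    n, d = len(x), len(x[0])
    ts = [-t_max + 2 * t_max * k / (points - 1) for k in range(points)]
    total = 0.0
    for _ in range(slices):
        # Random unit direction: Gaussian components, normalized.
        g = [rng.gauss(0, 1) for _ in range(d)]
        norm = math.sqrt(sum(v * v for v in g))
        u = [v / norm for v in g]
        proj = [sum(xi[j] * u[j] for j in range(d)) for xi in x]
        for t in ts:
            # Empirical characteristic function vs. the standard
            # normal's exp(-t^2 / 2).
            re = sum(math.cos(t * p) for p in proj) / n
            im = sum(math.sin(t * p) for p in proj) / n
            target = math.exp(-t * t / 2)
            total += (re - target) ** 2 + im ** 2
    return coefficient * total / (slices * points)
```

The loss is nonnegative and small when the projections already look like standard Gaussian samples, which is the isotropy property being encouraged.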
### Flax

```python
from soma_models.v1.flax import Model, ModelConfig, SIGReg, SIGRegConfig, compute_loss
```

Same classes and interfaces as PyTorch, implemented with Flax nn.Module. All method signatures and behavior are identical.
### compute_loss

Compute the training loss: cross-entropy (ignoring PAD tokens) plus SIGReg regularization.

```python
from soma_models.v1.torch import compute_loss

loss, embedding = compute_loss(model, sig_reg, token_ids, targets)
```

| Parameter | Type | Description |
|---|---|---|
| model | Model | The V1 Model instance |
| sig_reg | SIGReg | SIGReg module (generates noise via global RNG) |
| token_ids | Tensor | Input token IDs, shape [batch, seq] |
| targets | Tensor | Next-token targets, shape [batch, seq] |

Returns: `(loss, embedding)`, where loss is a scalar tensor (CE + SIGReg) and embedding is the mean embedding of shape [embedding_dim].
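The "ignoring PAD tokens" part of the cross-entropy term can be sketched in plain Python; `cross_entropy_ignore_pad` is an illustrative stand-in for the masked-CE half of `compute_loss`, not the library function itself:

```python
import math

PAD = 256  # V1_PAD_TOKEN_ID

def cross_entropy_ignore_pad(logits: list[list[float]],
                             targets: list[int]) -> float:
    """Mean per-token cross-entropy over [seq][vocab] logits,
    skipping positions whose target is PAD."""
    total, count = 0.0, 0
    for row, target in zip(logits, targets):
        if target == PAD:
            continue  # padded positions contribute nothing to the loss
        m = max(row)  # subtract the max for a numerically stable log-sum-exp
        log_z = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_z - row[target]
        count += 1
    return total / max(count, 1)
```

With uniform logits over a vocabulary of size V, the per-token loss is exactly log V, which is a useful baseline when checking a freshly initialized model.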
## Serialization

Model weights are stored in safetensors format. The canonical format differs from PyTorch conventions; the Serde class handles the conversion automatically.

### Serializable

Mixin for nn.Module subclasses. The Model class inherits this.

```python
# Save
model.save("weights.safetensors")
data = model.save_bytes()

# Load
model = Model.load("weights.safetensors", ModelConfig(dropout_rate=0.0))
model = Model.load_bytes(data, ModelConfig(dropout_rate=0.0))
```

| Method | Signature | Description |
|---|---|---|
| save | (filename) -> None | Save weights to safetensors file |
| save_bytes | () -> bytes | Serialize weights to bytes |
| load | (filename, *args, **kwargs) -> cls | Class method: load from file |
| load_bytes | (data: bytes, *args, **kwargs) -> cls | Class method: load from bytes |
### Serde

Low-level serialization. Handles LayerNorm key remapping (weight/bias ↔ gamma/beta) and linear weight transposition (row-major ↔ column-major).

```python
from soma_models.v1.torch.serde import Serde

serde = Serde(model)
data = serde.serialize()
serde.deserialize(data)
```

## Utilities

```python
from soma_models.utils import remap, flatten_dict, unflatten_dict
```

### remap

Rename and delete keys in a dictionary, in-place.

```python
remap(d, rename_map={"old": "new"}, keys_to_delete=["unwanted"])
```

### flatten_dict

Flatten a nested dictionary into dotted keys.

```python
flatten_dict({"a": {"b": 1}})  # {"a.b": 1}
```

### unflatten_dict

Reverse of flatten_dict.

```python
unflatten_dict({"a.b": 1})  # {"a": {"b": 1}}
```
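The two Serde conversions mentioned above can be illustrated on a toy state dict; `to_canonical` is a hypothetical sketch operating on plain lists, not the real Serde, and the `"norm"`-substring heuristic for spotting LayerNorm keys is an assumption:

```python
def to_canonical(state_dict: dict) -> dict:
    """Sketch of the Serde conversions: rename LayerNorm parameters
    (weight/bias -> gamma/beta) and transpose 2-D linear weights
    (row-major -> column-major)."""
    out = {}
    for key, value in state_dict.items():
        if "norm" in key:
            # LayerNorm parameters keep their values; only keys change.
            key = key.replace(".weight", ".gamma").replace(".bias", ".beta")
        elif key.endswith(".weight") and isinstance(value[0], list):
            # 2-D linear weight: swap rows and columns.
            value = [list(col) for col in zip(*value)]
        out[key] = value
    return out
```

Deserialization would apply the inverse mapping, since transposition and key renaming are both self-inverse up to direction.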