Train a Model
Every model on SOMA shares the same V1 byte-level transformer architecture. What differs is the weights — and training good weights is what earns you commission on the network.
Data Contract
SOMA models operate on raw bytes. Input data is chunked into sequences of `max_seq_len=1024` tokens and grouped into batches of `batch_size=16`. The vocabulary is 264 tokens: 256 byte values plus two special tokens.
| Token | ID |
|---|---|
| Byte values | 0–255 |
| PAD | 256 |
| EOS | 257 |
The `tokenize()` function handles chunking, padding, and special token insertion. You don't need to manage this manually.
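The byte-to-token mapping itself is trivial to state. A minimal sketch of the contract described above (the `byte_to_token` helper is hypothetical, shown only for illustration — in practice `tokenize()` does all of this for you):

```python
PAD, EOS = 256, 257  # special token IDs from the table above

def byte_to_token(b: int) -> int:
    # Byte values map one-to-one onto token IDs 0-255.
    return b

sample = b"hi"
token_ids = [byte_to_token(b) for b in sample]
print(token_ids)  # [104, 105]
```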
Tokenize Your Data
Prepare your training data as raw bytes and tokenize it:

```python
from soma_models.v1.tokenizer import tokenize

data: bytes = open("training_data.bin", "rb").read()
batches = tokenize(data, max_seq_len=1024, batch_size=16)
```

`tokenize()` returns a `list[ByteSequenceBatch]`. Each batch has three attributes:

- `.token_ids` — input token IDs (nested list of ints)
- `.targets` — shifted targets for next-token prediction
- `.pos_ids` — positional indices
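As a rough capacity check, you can estimate how many batches a file will produce. This back-of-envelope sketch assumes one token per byte and ignores any special tokens `tokenize()` may insert, so treat the result as an approximation:

```python
import math

def approx_num_batches(n_bytes: int, max_seq_len: int = 1024, batch_size: int = 16) -> int:
    # Each sequence holds up to max_seq_len tokens; sequences are
    # grouped into batches of batch_size.
    n_seqs = math.ceil(n_bytes / max_seq_len)
    return math.ceil(n_seqs / batch_size)

print(approx_num_batches(1_000_000))  # 62 batches for ~1 MB of input
```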
Initialize the model and regularizer:

PyTorch:

```python
from soma_models.v1.configs import ModelConfig, SIGRegConfig
from soma_models.v1.torch import Model, SIGReg

model = Model(ModelConfig(dropout_rate=0.1))
sig_reg = SIGReg(SIGRegConfig())
```

Flax:

```python
from flax import nnx
from soma_models.v1.configs import ModelConfig, SIGRegConfig
from soma_models.v1.flax import Model, SIGReg

model = Model(ModelConfig(dropout_rate=0.1), rngs=nnx.Rngs(0))
sig_reg = SIGReg(SIGRegConfig(), rngs=nnx.Rngs(1))
```

`ModelConfig` defaults match the V1 architecture: `embedding_dim=2048`, `num_layers=24`, `num_heads=8`, `pwff_hidden_dim=8192`, `vocab_size=264`. `SIGRegConfig` defaults are `t_max=3.0`, `points=17`, `slices=256`, `coefficient=0.02`.
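The defaults above imply a model on the order of a billion parameters. A back-of-envelope estimate that counts only the embedding table, attention projections, and feed-forward matrices (biases, norms, and any output head are ignored, so this is an approximation, not an exact figure):

```python
embedding_dim, num_layers, pwff_hidden_dim, vocab_size = 2048, 24, 8192, 264

embed_params = vocab_size * embedding_dim            # token embedding table
attn_per_layer = 4 * embedding_dim ** 2              # Q, K, V, output projections
ffn_per_layer = 2 * embedding_dim * pwff_hidden_dim  # two feed-forward matrices
total = embed_params + num_layers * (attn_per_layer + ffn_per_layer)
print(f"~{total / 1e9:.2f}B parameters")  # ~1.21B
```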
Run the training loop:

PyTorch:

```python
import torch
from soma_models.v1.torch import compute_loss

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

for batch in batches:
    token_ids = torch.tensor(batch.token_ids)
    targets = torch.tensor(batch.targets)
    loss, embedding = compute_loss(model, sig_reg, token_ids, targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```

Flax:

```python
import jax
import jax.numpy as jnp
import optax
from soma_models.v1.flax import compute_loss

optimizer = optax.adam(1e-4)
opt_state = optimizer.init(nnx.state(model))

for batch in batches:
    token_ids = jnp.array(batch.token_ids)
    targets = jnp.array(batch.targets)
    grad_fn = jax.grad(lambda m, sr: compute_loss(m, sr, token_ids, targets)[0])
    grads = grad_fn(model, sig_reg)
    updates, opt_state = optimizer.update(grads, opt_state)
    nnx.update(model, updates)

_, embedding = compute_loss(model, sig_reg, token_ids, targets)
```

`compute_loss` returns a tuple of `(loss, embedding)`. The loss combines cross-entropy with SIGReg (Gaussian uniformity regularization). The embedding is the mean representation across all non-PAD tokens — you'll need it later when registering your model.
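For monitoring convergence during the loop above, a smoothed loss is easier to read than raw per-batch values. A minimal, framework-agnostic helper — not part of soma_models, just an illustrative sketch you could call with `ema.update(float(loss))` inside either loop:

```python
class EMALoss:
    """Tracks an exponential moving average of the training loss."""

    def __init__(self, beta: float = 0.98):
        self.beta = beta
        self.value: float | None = None

    def update(self, loss: float) -> float:
        # The first observation seeds the average; later ones
        # blend in with weight (1 - beta).
        if self.value is None:
            self.value = loss
        else:
            self.value = self.beta * self.value + (1 - self.beta) * loss
        return self.value

ema = EMALoss(beta=0.5)
for loss in [2.0, 4.0, 4.0]:
    smoothed = ema.update(loss)
print(smoothed)  # 3.5
```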
Save your trained weights:

```python
model.save("weights.safetensors")
```

Weights are stored in safetensors format. You can also use `model.save_bytes()` to get the serialized weights as an in-memory bytes object.
Verify by loading the weights back and checking loss on held-out data:

```python
loaded_model = Model.load("weights.safetensors")
loss, embedding = compute_loss(loaded_model, sig_reg, held_out_ids, held_out_targets)
print(f"Held-out loss: {loss}")
```
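If you don't already have a held-out set, a simple tail split of the raw bytes works. A sketch (the 90/10 ratio is an arbitrary choice; `held_out_ids` and `held_out_targets` would then come from running `tokenize()` on the held-out bytes):

```python
def split_bytes(data: bytes, train_frac: float = 0.9) -> tuple[bytes, bytes]:
    # Hold out the trailing fraction of the file for evaluation.
    cut = int(len(data) * train_frac)
    return data[:cut], data[cut:]

train_bytes, held_out_bytes = split_bytes(bytes(range(100)))
print(len(train_bytes), len(held_out_bytes))  # 90 10
```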
Cross-Framework Compatibility
The safetensors format is cross-compatible between PyTorch and Flax. You can train in one framework and load in the other. This is useful if you prefer Flax for training on TPUs but want to verify with PyTorch, or vice versa.
Next Steps
Once you have a trained weights file and a final embedding, you're ready to register your model on-chain.