Train a Model
Every model on SOMA shares the same V1 byte-level transformer architecture. What differs is the weights. Train good weights, publish them on-chain via commit-reveal, and earn commission every time your model produces the winning embedding for a data submission.
The lifecycle is a loop: train → commit → reveal → train more → commit update → reveal update → repeat. Eventually you can automate this entirely with a Modal cron.
Train weights
PyTorch:

```sh
uv run modal run --detach src/quickstart/train_torch.py
```

Runs on an A100 GPU. Pass `--num-steps` to control training length (default: 10,000).

Flax:

```sh
uv run modal run --detach src/quickstart/train_flax.py
```

Runs on an H100 GPU. Same `--num-steps` flag.
Checkpoints are saved every 500 steps to a Modal volume (soma-training-data). Final weights are saved as model-final.safetensors.
Code walkthrough
Image and config
Section titled “Image and config”The Modal image installs soma-models[torch], datasets, and PyTorch:
```python
image = modal.Image.debian_slim(python_version="3.13").pip_install(
    "soma-models[torch]>=0.1.7",
    "datasets>=3.0",
    "torch>=2.0",
)

CHECKPOINT_EVERY = 500
LEARNING_RATE = 1e-4
DROPOUT_RATE = 0.1
MICRO_BATCH_SIZE = 2
GRAD_ACCUM_STEPS = 8  # effective batch size = 2 * 8 = 16
```

Data pipeline
Streams FineWeb from HuggingFace and tokenizes raw bytes into fixed-length sequences using the soma_models tokenizer:
```python
def make_batches(batch_size: int):
    from datasets import load_dataset
    from soma_models.v1.configs import V1_MAX_SEQ_LEN
    from soma_models.v1.tokenizer import tokenize

    ds = load_dataset(
        "HuggingFaceFW/fineweb", "sample-10BT", split="train", streaming=True
    )

    buffer_ids, buffer_targets = [], []
    for example in ds:
        text = example.get("text", "")
        if not text.strip():
            continue
        sequences = tokenize(data=text.encode("utf-8"), max_seq_len=V1_MAX_SEQ_LEN)
        for seq in sequences:
            buffer_ids.append(seq.token_ids)
            buffer_targets.append(seq.targets)
            if len(buffer_ids) == batch_size:
                yield buffer_ids, buffer_targets
                buffer_ids, buffer_targets = [], []
```

Training loop
Initializes the V1 model with SIGReg regularization, then trains with gradient accumulation (8 micro-batches per step for an effective batch size of 16):
```python
@app.function(image=image, gpu="A100", timeout=86400, volumes={MODEL_DIR: volume})
def train(num_steps: int = 10_000):
    import torch
    from soma_models.v1.torch import Model, ModelConfig, SIGReg, SIGRegConfig, compute_loss

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Model(ModelConfig(dropout_rate=DROPOUT_RATE)).to(device)
    model.train()
    sig_reg = SIGReg(SIGRegConfig()).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

    batches = make_batches(MICRO_BATCH_SIZE)

    for step in range(num_steps):
        optimizer.zero_grad()
        accum_loss = 0.0

        for micro in range(GRAD_ACCUM_STEPS):
            ids, tgts = next(batches)
            token_ids = torch.tensor(ids, device=device)
            targets = torch.tensor(tgts, device=device)

            loss, embedding = compute_loss(model, sig_reg, token_ids, targets)
            (loss / GRAD_ACCUM_STEPS).backward()
            accum_loss += loss.item()

        optimizer.step()

        if step > 0 and step % CHECKPOINT_EVERY == 0:
            model.save(f"{MODEL_DIR}/checkpoint-{step}.safetensors")
            volume.commit()

    model.save(f"{MODEL_DIR}/model-final.safetensors")
    volume.commit()
```

The Flax version follows the same structure.

Image and config
The Modal image installs soma-models[flax], datasets, optax, and JAX with CUDA:
```python
image = modal.Image.debian_slim(python_version="3.13").pip_install(
    "soma-models[flax]>=0.1.7",
    "datasets>=3.0",
    "optax>=0.2",
    "jax[cuda12]>=0.5",
)

CHECKPOINT_EVERY = 500
LEARNING_RATE = 1e-4
DROPOUT_RATE = 0.1
MICRO_BATCH_SIZE = 2
GRAD_ACCUM_STEPS = 8  # effective batch size = 2 * 8 = 16
```

Data pipeline
Same as the PyTorch version: streams FineWeb and tokenizes with the soma_models tokenizer:
```python
def make_batches(batch_size: int):
    from datasets import load_dataset
    from soma_models.v1.configs import V1_MAX_SEQ_LEN
    from soma_models.v1.tokenizer import tokenize

    ds = load_dataset(
        "HuggingFaceFW/fineweb", "sample-10BT", split="train", streaming=True
    )

    buffer_ids, buffer_targets = [], []
    for example in ds:
        text = example.get("text", "")
        if not text.strip():
            continue
        sequences = tokenize(data=text.encode("utf-8"), max_seq_len=V1_MAX_SEQ_LEN)
        for seq in sequences:
            buffer_ids.append(seq.token_ids)
            buffer_targets.append(seq.targets)
            if len(buffer_ids) == batch_size:
                yield buffer_ids, buffer_targets
                buffer_ids, buffer_targets = [], []
```

Training loop
Uses Flax NNX with a JIT-compiled micro-step. Gradient accumulation is handled manually:
```python
@app.function(image=image, gpu="H100", timeout=86400, volumes={MODEL_DIR: volume})
def train(num_steps: int = 10_000):
    import jax
    import jax.numpy as jnp
    import optax
    from flax import nnx
    from soma_models.v1.flax import Model, ModelConfig, SIGReg, SIGRegConfig, compute_loss

    rngs = nnx.Rngs(0)
    model = Model(ModelConfig(dropout_rate=DROPOUT_RATE), rngs)
    model.train()
    sig_reg = SIGReg(SIGRegConfig(), rngs)
    optimizer = nnx.Optimizer(model, optax.adam(learning_rate=LEARNING_RATE), wrt=nnx.Param)

    @nnx.jit
    def micro_step(model, sig_reg, token_ids, targets):
        def loss_fn(model, sig_reg):
            return compute_loss(model, sig_reg, token_ids, targets)
        (loss, embedding), grads = nnx.value_and_grad(loss_fn, has_aux=True)(model, sig_reg)
        return loss, embedding, grads

    batches = make_batches(MICRO_BATCH_SIZE)

    for step in range(num_steps):
        accum_loss = jnp.zeros(())
        accum_grads = None

        for micro in range(GRAD_ACCUM_STEPS):
            ids, tgts = next(batches)
            loss, embedding, grads = micro_step(model, sig_reg, jnp.array(ids), jnp.array(tgts))
            accum_loss = accum_loss + loss
            accum_grads = grads if accum_grads is None else jax.tree.map(jnp.add, accum_grads, grads)

        accum_grads = jax.tree.map(lambda g: g / GRAD_ACCUM_STEPS, accum_grads)
        optimizer.update(model, accum_grads)

        if step > 0 and step % CHECKPOINT_EVERY == 0:
            model.save(f"{MODEL_DIR}/checkpoint-{step}.safetensors")
            volume.commit()

    model.save(f"{MODEL_DIR}/model-final.safetensors")
    volume.commit()
```

Download weights
```sh
modal volume get soma-training-data model-final.safetensors .
```

Data contract
Section titled “Data contract”SOMA models operate on raw bytes. Input data is chunked into non-overlapping sequences of max_seq_len=1024 tokens. The vocabulary is 264 tokens: 256 byte values plus special tokens.
| Token | ID |
|---|---|
| Byte values | 0–255 |
| PAD | 256 |
| EOS | 257 |
The tokenize() function handles chunking, padding, and special token insertion. compute_loss returns a tuple of (loss, embedding) — the embedding is the mean representation across all non-PAD tokens and is needed when registering your model.
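The chunking contract can be sketched in plain Python. This is conceptual only, not the real soma_models tokenizer: the actual tokenize() also emits next-token targets, and exactly where EOS lands relative to chunk boundaries is an assumption here.

```python
PAD, EOS = 256, 257  # special token IDs from the data contract
MAX_SEQ_LEN = 1024

def chunk_bytes(data: bytes, max_seq_len: int = MAX_SEQ_LEN) -> list[list[int]]:
    """Split raw bytes into non-overlapping, PAD-padded token sequences."""
    ids = list(data) + [EOS]  # each byte value is its own token ID; mark end of data
    chunks = []
    for start in range(0, len(ids), max_seq_len):
        chunk = ids[start:start + max_seq_len]
        chunk += [PAD] * (max_seq_len - len(chunk))  # right-pad the final chunk
        chunks.append(chunk)
    return chunks

seqs = chunk_bytes(b"hello world")
print(len(seqs), seqs[0][10:13])  # 1 [100, 257, 256] — 'd', EOS, then padding
```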
The safetensors format is cross-compatible between PyTorch and Flax. You can train in one framework and load in the other.
Register on-chain
Model registration uses a two-phase commit-reveal protocol. You post a cryptographic commitment of your weights in epoch N, then reveal in epoch N+1. This prevents front-running.
```python
import asyncio
import os

from soma_sdk import SomaClient, Keypair

WEIGHTS_PATH = "model-final.safetensors"
WEIGHTS_URL = "https://your-storage.example.com/weights.enc"
COMMISSION_RATE = 1000  # 10% (basis points: 100 = 1%, max 10000)

async def run():
    kp = Keypair.from_secret_key(os.environ["SOMA_SECRET_KEY"])
    client = await SomaClient(chain="testnet")

    # 1. Encrypt weights
    weights = open(WEIGHTS_PATH, "rb").read()
    encrypted_weights, decryption_key = SomaClient.encrypt_weights(weights)
    print(f"Encrypted {len(weights)} bytes")

    # 2. Upload encrypted weights to storage (S3, R2, etc.)
    # ... upload encrypted_weights to WEIGHTS_URL ...

    # 3. Compute embedding (from compute_loss on held-out data)
    embedding = [...]  # list[float] from training

    # 4. Commit on-chain (epoch N)
    model_id = await client.commit_model(
        signer=kp,
        weights_url=WEIGHTS_URL,
        encrypted_weights=encrypted_weights,
        commission_rate=COMMISSION_RATE,
    )
    print(f"Committed model: {model_id}")

    # 5. Wait for next epoch, then reveal (epoch N+1)
    await client.wait_for_next_epoch()

    await client.reveal_model(
        signer=kp,
        model_id=model_id,
        weights_url=WEIGHTS_URL,
        encrypted_weights=encrypted_weights,
        decryption_key=decryption_key,
        embedding=embedding,
    )
    print(f"Revealed model: {model_id} — now active and competing")

asyncio.run(run())
```

Encrypt weights
Weights are encrypted with AES-256-CTR before uploading. This ensures validators can't access your weights before the reveal phase.
```python
encrypted_weights, decryption_key = SomaClient.encrypt_weights(weights)
```

Upload the encrypted bytes to any accessible storage (S3, Cloudflare R2, etc.). You need a URL that validators can download from.
Commit
The commit posts a cryptographic hash of your weights and URL. No one can see the actual weights yet.
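Conceptually, a commitment is a hash that binds the encrypted weights to the URL they will be served from. Here is a SHA-256 sketch; the real on-chain commitment format is internal to soma_sdk, so the exact fields, ordering, and any salting are assumptions.

```python
import hashlib

def weight_commitment(encrypted_weights: bytes, weights_url: str) -> str:
    """Hash the ciphertext together with its download location."""
    h = hashlib.sha256()
    h.update(hashlib.sha256(encrypted_weights).digest())  # commit to the ciphertext
    h.update(weights_url.encode("utf-8"))                 # bind the URL
    return h.hexdigest()

digest = weight_commitment(b"\x00" * 32, "https://your-storage.example.com/weights.enc")
print(len(digest))  # 64 hex characters
```

Because the hash is deterministic, validators can later recompute it from the revealed weights and URL to check the reveal matches the commit.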
commission_rate is in basis points: 1000 = 10%. This is the share you earn when your model produces winning submissions.
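A quick sanity check on the units (hypothetical helper, not part of soma_sdk):

```python
def percent_to_bps(percent: float) -> int:
    """Convert a percentage to basis points (1% = 100 bps, max 10000)."""
    bps = round(percent * 100)
    if not 0 <= bps <= 10_000:
        raise ValueError("commission must be between 0% and 100%")
    return bps

print(percent_to_bps(10))   # 1000, matching COMMISSION_RATE above
print(percent_to_bps(0.5))  # 50
```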
Reveal
The reveal publishes the weights URL, decryption key, and your model's embedding. After this, your model is active and begins competing in target assignments.
The embedding is the list[float] returned by compute_loss during training — it determines which targets your model gets assigned to.
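The pooling described in the data contract, a mean over non-PAD positions, can be sketched with NumPy. The helper and the shapes are illustrative assumptions; the real pooling happens inside compute_loss.

```python
import numpy as np

PAD = 256  # PAD token ID from the data contract

def mean_embedding(token_ids: np.ndarray, hidden: np.ndarray) -> np.ndarray:
    """Average hidden states over non-PAD positions.

    token_ids: (seq_len,) int array; hidden: (seq_len, d_model) float array.
    """
    mask = token_ids != PAD
    return hidden[mask].mean(axis=0)

ids = np.array([104, 105, 256, 256])  # "hi" followed by two PAD tokens
h = np.array([[1.0, 2.0], [3.0, 4.0], [9.0, 9.0], [9.0, 9.0]])
print(mean_embedding(ids, h))  # [2. 3.] (PAD rows excluded from the average)
```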
Update weights
Once your model is active, you can push updated weights to stay competitive. Updates follow the same commit-reveal pattern.
```python
async def update(client, kp, model_id):
    # 1. Encrypt new weights
    new_weights = open("weights-v2.safetensors", "rb").read()
    enc_weights, dec_key = SomaClient.encrypt_weights(new_weights)

    # 2. Upload to storage
    new_url = "https://your-storage.example.com/weights-v2.enc"
    # ... upload enc_weights ...

    # 3. Embedding from compute_loss on held-out data
    new_embedding = [...]

    # 4. Commit update (epoch N)
    await client.commit_model_update(
        signer=kp,
        model_id=model_id,
        weights_url=new_url,
        encrypted_weights=enc_weights,
    )

    # 5. Reveal update (epoch N+1)
    await client.wait_for_next_epoch()

    await client.reveal_model_update(
        signer=kp,
        model_id=model_id,
        weights_url=new_url,
        encrypted_weights=enc_weights,
        decryption_key=dec_key,
        embedding=new_embedding,
    )
    print("Update revealed — new weights active")
```

Automate the loop
Schedule the full train → commit → reveal cycle as a recurring Modal job:
```python
import modal

app = modal.App("soma-auto-update")
volume = modal.Volume.from_name("soma-training-data", create_if_missing=True)

image = (
    modal.Image.debian_slim(python_version="3.13")
    .pip_install("soma-sdk>=0.1.7", "soma-models[torch]>=0.1.7", "datasets>=3.0", "torch>=2.0")
)

@app.function(
    image=image,
    gpu="A100",
    volumes={"/training": volume},
    secrets=[modal.Secret.from_name("soma-keypair")],
    timeout=86400,
    schedule=modal.Cron("0 */6 * * *"),  # every 6 hours
)
async def train_and_update():
    # 1. Resume from latest checkpoint
    # 2. Train for N more steps
    # 3. Save checkpoint
    # 4. Encrypt + upload weights
    # 5. commit_model_update / wait_for_next_epoch / reveal_model_update
    ...
```

Deploy with `modal deploy auto_update.py` to activate the schedule. The wait_for_next_epoch call blocks until the epoch boundary, so commit/reveal timing is handled automatically.
Other operations
Set commission rate
Takes effect next epoch:
```python
await client.set_model_commission_rate(signer, model_id, new_rate=500)
```

Deactivate your model
Stops competing:
```python
await client.deactivate_model(signer, model_id)
```

Stake on your model
```python
await client.add_stake_to_model(signer, model_id, amount=10.0)
```

List all models
```sh
soma model list
```