Train a Model

Every model on SOMA shares the same V1 byte-level transformer architecture. What differs is the weights — and training good weights is what earns you commission on the network.

SOMA models operate on raw bytes. Input data is chunked into sequences of max_seq_len=1024 tokens and grouped into batches of batch_size=16. The vocabulary is 264 tokens: 256 byte values plus two special tokens.

Token          ID
Byte values    0–255
PAD            256
EOS            257

The tokenize() function handles chunking, padding, and special token insertion. You don’t need to manage this manually.
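Conceptually, the chunking and padding work like the sketch below. This is a simplified stand-in for illustration only — the real tokenize() also builds targets and positional IDs, and may place special tokens differently:

```python
PAD, EOS = 256, 257
MAX_SEQ_LEN = 4  # tiny value for illustration; the real default is 1024

def chunk_and_pad(data: bytes, max_seq_len: int = MAX_SEQ_LEN) -> list[list[int]]:
    """Split raw bytes into fixed-length token sequences.

    Appends EOS after the data, then pads the final sequence with PAD
    up to max_seq_len. (Illustrative only.)
    """
    tokens = list(data) + [EOS]
    chunks = []
    for i in range(0, len(tokens), max_seq_len):
        chunk = tokens[i:i + max_seq_len]
        chunk += [PAD] * (max_seq_len - len(chunk))
        chunks.append(chunk)
    return chunks

print(chunk_and_pad(b"hello"))  # → [[104, 101, 108, 108], [111, 257, 256, 256]]
```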

  1. Prepare your training data as raw bytes and tokenize it:

    from soma_models.v1.tokenizer import tokenize
    with open("training_data.bin", "rb") as f:
        data: bytes = f.read()
    batches = tokenize(data, max_seq_len=1024, batch_size=16)

    tokenize() returns a list[ByteSequenceBatch]. Each batch has three attributes:

    • .token_ids — input token IDs (nested list of ints)
    • .targets — shifted targets for next-token prediction
    • .pos_ids — positional indices
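For next-token prediction, each target sequence should be the input sequence shifted one position to the left — a property you can sanity-check on any batch. A minimal sketch using a hand-built stand-in (the real ByteSequenceBatch comes from tokenize(); the trailing PAD in the targets is an assumption):

```python
from dataclasses import dataclass

@dataclass
class FakeBatch:
    """Hypothetical stand-in mirroring the attributes listed above."""
    token_ids: list[list[int]]
    targets: list[list[int]]
    pos_ids: list[list[int]]

def targets_are_shifted(batch) -> bool:
    # targets[i][t] should equal token_ids[i][t + 1] for every position t.
    return all(
        seq[1:] == tgt[:-1]
        for seq, tgt in zip(batch.token_ids, batch.targets)
    )

batch = FakeBatch(
    token_ids=[[104, 101, 108, 257]],
    targets=[[101, 108, 257, 256]],  # shifted left, PAD-filled at the end
    pos_ids=[[0, 1, 2, 3]],
)
print(targets_are_shifted(batch))  # → True
```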
  2. Initialize the model and regularizer:

    from soma_models.v1.configs import ModelConfig, SIGRegConfig
    from soma_models.v1.torch import Model, SIGReg
    model = Model(ModelConfig(dropout_rate=0.1))
    sig_reg = SIGReg(SIGRegConfig())

    ModelConfig defaults match the V1 architecture: embedding_dim=2048, num_layers=24, num_heads=8, pwff_hidden_dim=8192, vocab_size=264. SIGRegConfig defaults are t_max=3.0, points=17, slices=256, coefficient=0.02.
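Those defaults imply a model on the order of a billion parameters. A back-of-the-envelope count, assuming a standard transformer layout (four attention projection matrices plus a two-matrix feed-forward per layer; biases, norms, and any output head ignored):

```python
embedding_dim = 2048
num_layers = 24
pwff_hidden_dim = 8192
vocab_size = 264

embed = vocab_size * embedding_dim                  # token embedding table
attn_per_layer = 4 * embedding_dim * embedding_dim  # Q, K, V, output projections
ffn_per_layer = 2 * embedding_dim * pwff_hidden_dim # up- and down-projection
total = embed + num_layers * (attn_per_layer + ffn_per_layer)
print(f"{total / 1e9:.2f}B parameters (rough estimate)")  # → 1.21B parameters (rough estimate)
```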

  3. Run the training loop:

    import torch
    from soma_models.v1.torch import compute_loss
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    for batch in batches:
        token_ids = torch.tensor(batch.token_ids)
        targets = torch.tensor(batch.targets)
        loss, embedding = compute_loss(model, sig_reg, token_ids, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    compute_loss returns a tuple of (loss, embedding). The loss combines cross-entropy with SIGReg (Gaussian uniformity regularization). The embedding is the mean representation across all non-PAD tokens — you’ll need it later when registering your model.
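The "mean across all non-PAD tokens" can be pictured as a masked average. A plain-Python sketch over one sequence of per-token vectors (the real computation happens inside compute_loss over torch tensors):

```python
PAD = 256

def masked_mean(token_ids: list[int], vectors: list[list[float]]) -> list[float]:
    """Average the per-token vectors, skipping positions whose token is PAD."""
    kept = [v for t, v in zip(token_ids, vectors) if t != PAD]
    dim = len(vectors[0])
    return [sum(v[d] for v in kept) / len(kept) for d in range(dim)]

# Two real tokens followed by one PAD; only the first two vectors count.
emb = masked_mean([65, 66, PAD], [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]])
print(emb)  # → [2.0, 3.0]
```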

  4. Save your trained weights:

    model.save("weights.safetensors")

    Weights are stored in safetensors format. You can also use model.save_bytes() to get the serialized weights as an in-memory bytes object.

  5. Verify by loading the weights back and checking loss on held-out data:

    loaded_model = Model.load("weights.safetensors")
    loss, embedding = compute_loss(loaded_model, sig_reg, held_out_ids, held_out_targets)
    print(f"Held-out loss: {loss}")
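One simple way to get held-out data is to split your raw bytes before tokenizing — for example, reserving the tail of the file. The split ratio and helper below are illustrative, not part of the SOMA API:

```python
def split_bytes(data: bytes, held_out_fraction: float = 0.1) -> tuple[bytes, bytes]:
    """Split raw bytes into (train, held_out) portions, held-out at the tail."""
    cut = int(len(data) * (1 - held_out_fraction))
    return data[:cut], data[cut:]

train_data, held_out_data = split_bytes(b"0123456789")
print(train_data, held_out_data)  # → b'012345678' b'9'
```

Tokenize each portion separately with tokenize(), train on the first, and use batches from the second for the held-out check above.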

The safetensors format is cross-compatible between PyTorch and Flax. You can train in one framework and load in the other. This is useful if you prefer Flax for training on TPUs but want to verify with PyTorch, or vice versa.

Once you have a trained weights file and a final embedding, you’re ready to register your model on-chain.