# Model Strategies

Your model competes on weights and embedding placement. This guide covers strategies for improving both: learning from competitors, distillation, and positioning your embedding.

> Source: [`training.py`](https://github.com/soma-org/quickstart/blob/main/src/quickstart/training.py)

**Infrastructure:** These guides use [Modal](https://modal.com) for convenience. For sustainable long-term operation, consider more affordable GPU providers — Lambda Labs, vast.ai, RunPod, and others offer significantly lower per-hour rates. Or consider setting up your own hardware.

## Learn from the network

Every data submission is publicly accessible via its data URL. Every revealed model's weights are downloadable. Use both to improve your model.

### Download competitor weights

```python
from soma_sdk import SomaClient

client = await SomaClient(chain="testnet")

model_bytes = await client.fetch_model(model_id="0xABC...")
```

Load into your framework:

PyTorch:

```python
from soma_models.v1.torch import Model
from soma_models.v1.configs import ModelConfig

competitor = Model.load_bytes(
    model_bytes,
    ModelConfig(dropout_rate=0.0),
)
```

Or with Flax:

```python
from soma_models.v1.flax import Model
from soma_models.v1.configs import ModelConfig
from flax import nnx

competitor = Model.load_bytes(
    model_bytes,
    ModelConfig(dropout_rate=0.0),
    rngs=nnx.Rngs(0),
)
```

### Download submitted data

Fetch data from recently filled targets. This is data that scored well against at least one model:

```python
targets = await client.get_targets(status="filled", limit=50)

training_data = []
for target in targets:
    data = await client.fetch_submission_data(target.id)
    training_data.append(data)
```

Training on this data biases your model toward domains the network is actively exploring.
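Before training on fetched submissions you need to split them into model-sized inputs. A minimal sketch, assuming each fetched payload arrives as raw bytes (the `to_sequences` helper is hypothetical — adjust to the actual shape of the data your client returns):

```python
SEQ_LEN = 1024  # V1 context length in bytes

def to_sequences(payloads, seq_len=SEQ_LEN):
    """Chunk each payload into independent seq_len-byte training sequences."""
    sequences = []
    for payload in payloads:
        # Drop any trailing remainder shorter than one full context window
        for start in range(0, len(payload) - seq_len + 1, seq_len):
            chunk = payload[start:start + seq_len]
            sequences.append(list(chunk))  # byte values 0-255 serve as token ids
    return sequences
```

Each 1024-byte chunk becomes an independent training sequence, matching the V1 context length described below.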

## Distilling from competitors

The most direct way to compete is to learn from what's already working. There are several approaches, from simple to sophisticated.

### Fine-tune from competitor weights

Initialize your model from a strong competitor's checkpoint instead of random weights, then continue training with your own data:

```python
# Load competitor as starting point
competitor = Model.load_bytes(model_bytes, ModelConfig(dropout_rate=DROPOUT_RATE))
competitor = competitor.to(device)
competitor.train()

# Train as usual — the model starts from a better position
sig_reg = SIGReg(SIGRegConfig()).to(device)
optimizer = torch.optim.Adam(competitor.parameters(), lr=LEARNING_RATE)

for ids, tgts in make_batches(MICRO_BATCH_SIZE):
    token_ids = torch.tensor(ids, device=device)
    targets = torch.tensor(tgts, device=device)
    loss, embedding = compute_loss(competitor, sig_reg, token_ids, targets)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```
Or with Flax:

```python
# Load competitor as starting point
rngs = nnx.Rngs(0)
competitor = Model.load_bytes(
    model_bytes, ModelConfig(dropout_rate=DROPOUT_RATE), rngs=rngs
)
competitor.train()

sig_reg = SIGReg(SIGRegConfig(), rngs)
optimizer = nnx.Optimizer(
    competitor, optax.adam(learning_rate=LEARNING_RATE), wrt=nnx.Param
)

@nnx.jit
def step(model, sig_reg, token_ids, targets):
    def loss_fn(model, sig_reg):
        return compute_loss(model, sig_reg, token_ids, targets)
    (loss, embedding), grads = nnx.value_and_grad(loss_fn, has_aux=True)(model, sig_reg)
    return loss, embedding, grads

for ids, tgts in make_batches(MICRO_BATCH_SIZE):
    loss, embedding, grads = step(competitor, sig_reg, jnp.array(ids), jnp.array(tgts))
    optimizer.update(competitor, grads)
```

This is the simplest approach — you skip the cold-start phase and begin training from a proven set of weights.

### Knowledge distillation

Run both your model and a competitor forward on the same batches. Combine the standard SIGReg loss with a distillation term that pulls your model's representations toward the competitor's:

```python
teacher = Model.load_bytes(model_bytes, ModelConfig(dropout_rate=0.0)).to(device)
teacher.eval()

student = Model(ModelConfig(dropout_rate=DROPOUT_RATE)).to(device)
student.train()
sig_reg = SIGReg(SIGRegConfig()).to(device)
optimizer = torch.optim.Adam(student.parameters(), lr=LEARNING_RATE)

alpha = 0.5  # balance between task loss and distillation loss

for ids, tgts in make_batches(MICRO_BATCH_SIZE):
    token_ids = torch.tensor(ids, device=device)
    targets = torch.tensor(tgts, device=device)

    # Student's standard loss
    task_loss, student_embed = compute_loss(student, sig_reg, token_ids, targets)

    # Teacher's embedding (no grad)
    with torch.no_grad():
        _, teacher_embed = compute_loss(teacher, sig_reg, token_ids, targets)

    # Distillation: pull student embedding toward teacher
    distill_loss = torch.nn.functional.mse_loss(student_embed, teacher_embed)

    loss = alpha * task_loss + (1 - alpha) * distill_loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```
Or with Flax:

```python
rngs = nnx.Rngs(0)
teacher = Model.load_bytes(
    model_bytes, ModelConfig(dropout_rate=0.0), rngs=rngs
)
teacher.eval()

student = Model(ModelConfig(dropout_rate=DROPOUT_RATE), rngs)
student.train()
sig_reg = SIGReg(SIGRegConfig(), rngs)
optimizer = nnx.Optimizer(
    student, optax.adam(learning_rate=LEARNING_RATE), wrt=nnx.Param
)

alpha = 0.5

@nnx.jit
def distill_step(student, teacher, sig_reg, token_ids, targets):
    def loss_fn(student, sig_reg):
        task_loss, student_embed = compute_loss(student, sig_reg, token_ids, targets)
        _, teacher_embed = compute_loss(teacher, sig_reg, token_ids, targets)
        distill_loss = jnp.mean((student_embed - teacher_embed) ** 2)
        loss = alpha * task_loss + (1 - alpha) * distill_loss
        return loss, student_embed
    (loss, _), grads = nnx.value_and_grad(loss_fn, has_aux=True)(student, sig_reg)
    return loss, grads

for ids, tgts in make_batches(MICRO_BATCH_SIZE):
    loss, grads = distill_step(
        student, teacher, sig_reg, jnp.array(ids), jnp.array(tgts)
    )
    optimizer.update(student, grads)
```
### Weight averaging

Merge your weights with a competitor's using linear interpolation:

PyTorch:

```python
beta = 0.3  # how much of the competitor to mix in

for p_mine, p_theirs in zip(my_model.parameters(), competitor.parameters()):
    p_mine.data.lerp_(p_theirs.data, beta)
```

Or with Flax:

```python
beta = 0.3  # how much of the competitor to mix in

my_state = nnx.state(my_model)
their_state = nnx.state(competitor)
merged = jax.tree.map(
    lambda mine, theirs: (1 - beta) * mine + beta * theirs,
    my_state, their_state,
)
nnx.update(my_model, merged)
```

### Domain gap analysis

Evaluate your model on filled targets' data to find domains where you underperform, then focus training on those gaps:

```python
targets = await client.get_targets(status="filled", limit=50)

for target in targets:
    data = await client.fetch_submission_data(target.id)
    # Score your model vs the winning model on this data
    # High loss = domain gap worth training on
```
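Once you have per-target losses, rank the gaps to decide where to focus. A minimal sketch, assuming you have already labeled each target with a domain and computed both losses (the `(domain, my_loss, winner_loss)` triples and `rank_domain_gaps` helper are hypothetical):

```python
from collections import defaultdict

def rank_domain_gaps(results):
    """Average (my_loss - winner_loss) per domain, worst gaps first."""
    gaps = defaultdict(list)
    for domain, my_loss, winner_loss in results:
        gaps[domain].append(my_loss - winner_loss)
    ranked = {d: sum(v) / len(v) for d, v in gaps.items()}
    # Largest positive gap = domain where you trail the winner most
    return sorted(ranked.items(), key=lambda kv: kv[1], reverse=True)

results = [
    ("rust", 2.1, 1.4),
    ("rust", 2.3, 1.5),
    ("python", 1.2, 1.1),
]
print(rank_domain_gaps(results))  # domains with the largest gap first
```

The top entries are the domains worth adding training data for.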

## Your model's embedding

The embedding you register determines which [targets](https://docs.soma.org/concepts/targets/#model-assignment) your model competes for. The KNN router assigns each target to the models with the nearest embeddings, so your position in embedding space controls both your target volume and your competition level.

Models clustered together share targets and compete purely on weight quality. Models in sparse regions get more targets with less competition. But you can't bluff your position — if your embedding is far from your actual strength, you'll receive targets but lose them (high loss). The dominant strategy is to **specialize honestly**: find an underserved region, train on data from that domain, and register an embedding that reflects where your model actually performs well.

### Genesis embedding

Compute your initial embedding from a representative sample of your training data:

1. Sample 256 random sequences of full context length (1024 bytes) from your training corpus
2. For each sequence, forward pass through your model and mean-pool the final layer output across byte positions → one 2048-dim vector per sequence
3. Average all 256 vectors, then L2-normalize
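The steps above can be sketched as follows. `model_forward` is a hypothetical stand-in for your model's final-layer forward pass; here it returns random activations so the sketch runs standalone:

```python
import numpy as np

rng = np.random.default_rng(0)

def model_forward(seq):
    # Hypothetical stand-in: returns a (seq_len, 2048) array of
    # per-byte final-layer activations for one sequence.
    return rng.standard_normal((len(seq), 2048))

# 1. Sample 256 full-context sequences (1024 bytes each) from your corpus
sequences = [rng.integers(0, 256, size=1024) for _ in range(256)]

# 2. Mean-pool each sequence's final-layer output across byte positions
pooled = np.stack([model_forward(s).mean(axis=0) for s in sequences])

# 3. Average all 256 vectors, then L2-normalize
embedding = pooled.mean(axis=0)
embedding /= np.linalg.norm(embedding)
```

The result is one unit-length 2048-dim vector summarizing where your training data sits in embedding space.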

### Updating your embedding

After your first 100 competitive wins (or 7 days since last update, whichever comes first), recompute:

1. Forward pass on all won data from the period, mean-pool each sequence's final-layer output across positions
2. Average all resulting vectors, L2-normalize
3. Re-register via `commit_model`

Repeat on the same cadence. Each update pulls your embedding toward where your model is actually winning — reinforcing your specialization.

### Finding gaps

Query the registry to find sparse regions of embedding space where few models compete:

```python
state = await client.get_latest_system_state()

for model in state.model_registry:
    print(model.id, model.embedding)
```

Position your model in an underserved region, then focus training data on the corresponding domains. This reduces direct competition and increases your share of targets in that region.
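One way to score candidate positions is by their distance to the nearest registered embedding: the larger the distance, the sparser the region. A minimal sketch with random placeholder vectors standing in for the registry's `model.embedding` values (the `nearest_model_distance` helper is hypothetical):

```python
import numpy as np

rng = np.random.default_rng(0)

def nearest_model_distance(candidates, registry):
    """For each candidate vector, distance to its nearest registered embedding."""
    # (n_candidates, n_models) pairwise Euclidean distances
    dists = np.linalg.norm(candidates[:, None, :] - registry[None, :, :], axis=-1)
    return dists.min(axis=1)

# Placeholder unit vectors standing in for registered model embeddings
registry = rng.standard_normal((20, 2048))
registry /= np.linalg.norm(registry, axis=1, keepdims=True)

# Candidate embeddings computed from your own domain samples
candidates = rng.standard_normal((5, 2048))
candidates /= np.linalg.norm(candidates, axis=1, keepdims=True)

scores = nearest_model_distance(candidates, registry)
sparsest = int(scores.argmax())  # candidate sitting in the most open region
```

In practice you would build `candidates` from embeddings of domains you could plausibly serve, and pick the one with the largest nearest-model distance.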

## Recommended datasets

The current V1 architecture has a context length of **1024 bytes**. Documents are chunked into independent sequences — the model doesn't see cross-sequence context. This means shorter, self-contained passages (functions, docstrings, paragraphs) are more effective than long documents.

Use the same datasets for training as you would for [data submission](https://docs.soma.org/guides/data-strategies/#recommended-datasets):

- [The Stack v2](https://huggingface.co/datasets/bigcode/the-stack-v2-dedup) — the base corpus (default in quickstart)
- [StarCoderData](https://huggingface.co/datasets/bigcode/starcoderdata) — higher quality filtering than raw Stack v2
- [FineWeb-Edu](https://huggingface.co/datasets/HuggingFaceFW/fineweb-edu) — natural language grounding for byte-level English
- [SWE-bench](https://huggingface.co/datasets/princeton-nlp/SWE-bench) — real GitHub issues paired with code fixes
- LLM-generated synthetic data — see [Data Strategies: LLM distillation](https://docs.soma.org/guides/data-strategies/#llm-distillation-generative)

## Next steps

[Data Strategies](https://docs.soma.org/guides/data-strategies/)

[Continuous Training](https://docs.soma.org/guides/model-development/)