# Models Reference

## Installation

`pip install soma-models`

PyTorch and Flax are optional extras. Install whichever framework you need.

```shell
pip install "soma-models[torch]"  # PyTorch
pip install "soma-models[flax]"   # Flax/JAX
```

## Configuration

Five dataclass configs control the model architecture. Every field has a sensible default matching the V1 architecture, except `dropout_rate`, which must be supplied.

### ModelConfig

Top-level model configuration.

```python
from soma_models.v1.configs import ModelConfig

config = ModelConfig(dropout_rate=0.1)
```

| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `dropout_rate` | `float` | - | Dropout probability (required) |
| `embedding_dim` | `int` | `2048` | Token embedding dimension |
| `pwff_hidden_dim` | `int` | `8192` | Feed-forward hidden dimension |
| `num_layers` | `int` | `24` | Number of encoder layers |
| `num_heads` | `int` | `8` | Number of attention heads |
| `vocab_size` | `int` | `264` | Vocabulary size (256 bytes + 8 special tokens) |
| `max_wavelength` | `float` | `10000.0` | RoPE max wavelength |
| `scale_factor` | `float` | `1.0` | RoPE scale factor |
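
As a back-of-envelope check on the defaults above, the table's dimensions imply a model of roughly 1.2B parameters. The sketch below ignores biases, LayerNorm parameters, and any weight tying, so the exact count in the shipped model may differ slightly:

```python
# Rough parameter estimate from the V1 defaults above.
# Ignores biases and LayerNorm parameters, so this is approximate.
embedding_dim = 2048
pwff_hidden_dim = 8192
num_layers = 24
vocab_size = 264

attn_per_layer = 4 * embedding_dim * embedding_dim    # Q, K, V, output projections
pwff_per_layer = 2 * embedding_dim * pwff_hidden_dim  # two feed-forward linears
embedding = vocab_size * embedding_dim                # token embedding table
predictor = embedding_dim * vocab_size                # projection to vocab logits

total = num_layers * (attn_per_layer + pwff_per_layer) + embedding + predictor
print(f"{total / 1e9:.2f}B parameters")  # prints 1.21B parameters
```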

### EncoderConfig

| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `dropout_rate` | `float` | - | Dropout probability (required) |
| `embedding_dim` | `int` | `2048` | Embedding dimension |
| `pwff_hidden_dim` | `int` | `8192` | Feed-forward hidden dimension |
| `num_layers` | `int` | `24` | Number of layers |
| `num_heads` | `int` | `8` | Attention heads |
| `max_wavelength` | `float` | `10000.0` | RoPE max wavelength |
| `scale_factor` | `float` | `1.0` | RoPE scale factor |

### LayerConfig

| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `dropout_rate` | `float` | - | Dropout probability (required) |
| `embedding_dim` | `int` | `2048` | Embedding dimension |
| `pwff_hidden_dim` | `int` | `8192` | Feed-forward hidden dimension |
| `num_heads` | `int` | `8` | Attention heads |
| `max_wavelength` | `float` | `10000.0` | RoPE max wavelength |
| `scale_factor` | `float` | `1.0` | RoPE scale factor |

### PositionWiseFeedForwardConfig

| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `dropout_rate` | `float` | - | Dropout probability (required) |
| `embedding_dim` | `int` | `2048` | Input/output dimension |
| `pwff_hidden_dim` | `int` | `8192` | Hidden dimension |

### SIGRegConfig

Spectral Isotropy Gaussian Regularization config.

| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `t_max` | `float` | `3.0` | Maximum integration bound |
| `points` | `int` | `17` | Number of quadrature points |
| `slices` | `int` | `256` | Number of random projections |
| `coefficient` | `float` | `0.02` | Regularization weight |

## Constants

| Constant | Value | Description |
|----------|-------|-------------|
| `V1_EMBEDDING_DIM` | `2048` | Token embedding dimension |
| `V1_NUM_HEADS` | `8` | Attention heads |
| `V1_NUM_LAYERS` | `24` | Encoder layers |
| `V1_MAX_SEQ_LEN` | `1024` | Maximum sequence length |
| `V1_VOCAB_SIZE` | `264` | 256 byte values + 8 special tokens |
| `V1_PWFF_HIDDEN_DIM` | `8192` | Feed-forward hidden dimension |
| `V1_MAX_WAVELENGTH` | `10000.0` | RoPE max wavelength |
| `V1_SCALE_FACTOR` | `1.0` | RoPE scale factor |
| `V1_PAD_TOKEN_ID` | `256` | Padding token ID |
| `V1_EOS_TOKEN_ID` | `257` | End-of-sequence token ID |
| `V1_SIG_REG_T_MAX` | `3.0` | SIGReg integration bound |
| `V1_SIG_REG_SLICES` | `256` | SIGReg random projections |
| `V1_SIG_REG_POINTS` | `17` | SIGReg quadrature points |
| `V1_SIG_REG_COEFFICIENT` | `0.02` | SIGReg weight |
| `V1_BATCH_SIZE` | `16` | Default batch size |
| `ARCHITECTURE_VERSION` | `1` | Must match network `model_architecture_version` |

## Tokenizer

Framework-agnostic byte-level tokenizer matching the network's V1 data contract.

### tokenize

Convert raw bytes into a list of tokenized sequences.

```python
from soma_models.v1.tokenizer import tokenize

sequences = tokenize(data, max_seq_len=1024)
```

| Parameter | Type | Required | Default | Description |
|-----------|------|----------|---------|-------------|
| `data` | `bytes \| bytearray` | Yes | - | Raw byte data to tokenize |
| `max_seq_len` | `int` | No | `1024` | Maximum sequence length per chunk |

**Returns**: `list[TokenizedSequence]`. One sequence per chunk of the input data.

### TokenizedSequence

| Field | Type | Description |
|-------|------|-------------|
| `token_ids` | `list[int]` | `[seq_len]`, input token IDs |
| `targets` | `list[int]` | `[seq_len]`, next-token targets (shifted left, PAD appended) |
| `pos_ids` | `list[int]` | `[seq_len]`, global byte offsets |
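
To make the contract concrete, here is a minimal pure-Python sketch of what `tokenize` produces per the rules above: token IDs are raw byte values, targets are the IDs shifted left with `V1_PAD_TOKEN_ID` (256) appended, and `pos_ids` are global byte offsets that continue across chunks. `sketch_tokenize` is illustrative only; use the shipped `tokenize` in real code:

```python
PAD_TOKEN_ID = 256  # V1_PAD_TOKEN_ID

def sketch_tokenize(data: bytes, max_seq_len: int = 1024):
    """Illustrative re-implementation of the V1 tokenize contract."""
    sequences = []
    for start in range(0, len(data), max_seq_len):
        chunk = data[start:start + max_seq_len]
        token_ids = list(chunk)                           # byte values 0..255
        targets = token_ids[1:] + [PAD_TOKEN_ID]          # next-token targets
        pos_ids = list(range(start, start + len(chunk)))  # global byte offsets
        sequences.append(
            {"token_ids": token_ids, "targets": targets, "pos_ids": pos_ids}
        )
    return sequences

seqs = sketch_tokenize(b"hello world", max_seq_len=8)
print(len(seqs))           # 2 chunks
print(seqs[1]["pos_ids"])  # [8, 9, 10] -- offsets continue across chunks
```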

## Model Components

```python
from soma_models.v1.torch import Model, ModelConfig, SIGReg, SIGRegConfig, compute_loss
```

### Model

Inherits `Serializable` and `nn.Module`. Full transformer: embedding → encoder → final norm → predictor.

```python
model = Model(ModelConfig(dropout_rate=0.1))
```

| Method | Signature | Description |
|--------|-----------|-------------|
| `encode` | `(input: Tensor, positions: Tensor) -> Tensor` | Embed + encode, returns `[batch, seq, dim]` |
| `predict` | `(embeddings: Tensor) -> Tensor` | Linear projection to vocab logits |
| `forward` | `(input: Tensor, positions: Tensor) -> Tensor` | `predict(encode(input, positions))` |

### Encoder

Stack of `Layer` modules.

```python
from soma_models.v1.torch.modules.encoder import Encoder
```

| Method | Signature | Description |
|--------|-----------|-------------|
| `forward` | `(input: Tensor, positions: Tensor) -> Tensor` | Pass through all layers sequentially |

### Layer

Pre-norm transformer layer: LayerNorm → Attention → residual → LayerNorm → FeedForward → residual.

```python
from soma_models.v1.torch.modules.layer import Layer
```

### MultiHeadAttention

Causal multi-head attention with Rotary Position Embeddings (RoPE).

```python
from soma_models.v1.torch.modules.attention import MultiHeadAttention
```

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `num_heads` | `int` | - | Number of attention heads |
| `num_features` | `int` | - | Input/output dimension |
| `dropout_rate` | `float` | `0.0` | Attention dropout |
| `use_bias` | `bool` | `True` | Use bias in projections |
| `max_wavelength` | `float` | `10000.0` | RoPE max wavelength |
| `scale_factor` | `float` | `1.0` | RoPE scale factor |

### PositionWiseFeedForward

Two-layer feed-forward with GELU activation.

```python
from soma_models.v1.torch.modules.pwff import PositionWiseFeedForward
```

### SIGReg

Spectral Isotropy Gaussian Regularization. Penalizes anisotropy in representations via random Gaussian projections.

```python
sig_reg = SIGReg(SIGRegConfig())
loss = sig_reg(representations)  # scalar tensor
```

| Method | Signature | Description |
|--------|-----------|-------------|
| `forward` | `(x: Tensor) -> Tensor` | Compute regularization loss (generates noise internally) |
| `compute` | `(x: Tensor, noise: Tensor) -> Tensor` | Compute with explicit noise tensor |
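
The exact objective is internal to the module, but conceptually the regularizer compares 1-D random projections ("slices") of the representations against a standard Gaussian. The numpy sketch below is a hypothetical illustration of that idea, not the shipped implementation: it reuses the config's `slices`, `points`, and `t_max` as the projection count and the quadrature grid for a characteristic-function comparison:

```python
import numpy as np

def sketch_sig_reg(x, slices=256, points=17, t_max=3.0, seed=0):
    """Illustrative only: penalize deviation of random 1-D projections of x
    from a standard Gaussian, measured via characteristic functions on a
    quadrature grid over [0, t_max]. Not the shipped SIGReg objective."""
    rng = np.random.default_rng(seed)
    n, d = x.shape
    # Random unit directions ("slices") to project onto.
    dirs = rng.standard_normal((d, slices))
    dirs /= np.linalg.norm(dirs, axis=0, keepdims=True)
    proj = x @ dirs                              # [n, slices]
    t = np.linspace(0.0, t_max, points)          # quadrature points
    # Empirical characteristic function of each slice vs. the Gaussian's.
    emp = np.exp(1j * proj[:, :, None] * t).mean(axis=0)  # [slices, points]
    gauss = np.exp(-0.5 * t**2)                           # [points]
    return float(np.mean(np.abs(emp - gauss) ** 2))

x = np.random.default_rng(1).standard_normal((512, 64))
print(sketch_sig_reg(x))  # near zero for isotropic Gaussian inputs
```

For isotropic Gaussian inputs the penalty is close to zero; stretching the representation along some axes drives it up, which is the anisotropy signal the regularizer targets.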

### Flax

```python
from soma_models.v1.flax import Model, ModelConfig, SIGReg, SIGRegConfig, compute_loss
```

The Flax variants expose the same classes and interfaces as the PyTorch ones, implemented as Flax `nn.Module` subclasses. All method signatures and behavior are identical.

## Loss

### compute_loss

Compute training loss: cross-entropy (ignoring PAD tokens) plus SIGReg regularization.

```python
from soma_models.v1.torch import compute_loss

loss, embedding = compute_loss(model, sig_reg, token_ids, targets)
```

| Parameter | Type | Description |
|-----------|------|-------------|
| `model` | `Model` | The V1 Model instance |
| `sig_reg` | `SIGReg` | SIGReg module (generates noise via global RNG) |
| `token_ids` | `Tensor` | Input token IDs, shape `[batch, seq]` |
| `targets` | `Tensor` | Next-token targets, shape `[batch, seq]` |

**Returns**: `(loss, embedding)` where `loss` is a scalar tensor (CE + SIGReg) and `embedding` is the mean embedding of shape `[embedding_dim]`.
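
The cross-entropy term skips positions whose target is the PAD token (ID 256). A small numpy sketch of that masking logic; `masked_cross_entropy` is illustrative, while the shipped `compute_loss` operates on the framework's own tensors and also adds the SIGReg term:

```python
import numpy as np

PAD_TOKEN_ID = 256  # V1_PAD_TOKEN_ID

def masked_cross_entropy(logits, targets):
    """Mean cross-entropy over non-PAD positions.
    logits: [batch, seq, vocab]; targets: [batch, seq]."""
    # Log-softmax over the vocab dimension, numerically stabilized.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Negative log-likelihood of each target token.
    nll = -np.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0]
    mask = targets != PAD_TOKEN_ID  # ignore padded positions
    return float((nll * mask).sum() / mask.sum())

logits = np.zeros((1, 4, 264))            # uniform logits over the vocab
targets = np.array([[65, 66, 256, 256]])  # last two positions are PAD
print(masked_cross_entropy(logits, targets))  # log(264), about 5.576
```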

## Serialization

Model weights are stored in the safetensors format. The canonical layout differs from PyTorch's native state-dict conventions; the `Serde` class handles the conversion automatically.

### Serializable

Mixin for `nn.Module` subclasses. The `Model` class inherits this.

```python
# Save
model.save("weights.safetensors")
data = model.save_bytes()

# Load
model = Model.load("weights.safetensors", ModelConfig(dropout_rate=0.0))
model = Model.load_bytes(data, ModelConfig(dropout_rate=0.0))
```

| Method | Signature | Description |
|--------|-----------|-------------|
| `save` | `(filename) -> None` | Save weights to safetensors file |
| `save_bytes` | `() -> bytes` | Serialize weights to bytes |
| `load` | `(filename, *args, **kwargs) -> cls` | Class method: load from file |
| `load_bytes` | `(data: bytes, *args, **kwargs) -> cls` | Class method: load from bytes |

### Serde

Low-level serialization. Handles LayerNorm key remapping (`weight`/`bias` ↔ `gamma`/`beta`) and linear weight transposition (row-major ↔ column-major).

```python
from soma_models.v1.torch.serde import Serde

serde = Serde(model)
data = serde.serialize()
serde.deserialize(data)
```
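
The two conversions can be pictured with plain Python data: rename the LayerNorm keys and transpose each linear weight matrix. The key names and `to_canonical` helper below are hypothetical; the real `Serde` works on safetensors state dicts and its exact key patterns may differ:

```python
def to_canonical(state):
    """Sketch of the PyTorch -> canonical direction: rename LayerNorm
    keys (weight/bias -> gamma/beta) and transpose linear weights
    (row-major -> column-major). Illustrative only."""
    out = {}
    for key, value in state.items():
        if key.endswith("norm.weight"):
            key = key.replace("norm.weight", "norm.gamma")
        elif key.endswith("norm.bias"):
            key = key.replace("norm.bias", "norm.beta")
        elif key.endswith("linear.weight"):
            value = [list(col) for col in zip(*value)]  # transpose the matrix
        out[key] = value
    return out

state = {
    "norm.weight": [1.0, 1.0],
    "linear.weight": [[1, 2, 3], [4, 5, 6]],  # [out, in] in PyTorch
}
canonical = to_canonical(state)
print(sorted(canonical))           # ['linear.weight', 'norm.gamma']
print(canonical["linear.weight"])  # [[1, 4], [2, 5], [3, 6]]
```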

## Utilities

```python
from soma_models.utils import remap, flatten_dict, unflatten_dict
```

### remap

Rename and delete keys in a dictionary, in-place.

```python
remap(d, rename_map={"old": "new"}, keys_to_delete=["unwanted"])
```
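
An illustrative equivalent of the described behavior, showing the effect on a concrete dict; `sketch_remap` is a sketch and the shipped `remap` may handle edge cases differently:

```python
def sketch_remap(d, rename_map=None, keys_to_delete=None):
    """Illustrative equivalent of remap: rename keys, then delete keys,
    mutating the dict in place."""
    for old, new in (rename_map or {}).items():
        if old in d:
            d[new] = d.pop(old)
    for key in keys_to_delete or []:
        d.pop(key, None)
    return d

d = {"old": 1, "unwanted": 2, "keep": 3}
sketch_remap(d, rename_map={"old": "new"}, keys_to_delete=["unwanted"])
print(d)  # {'keep': 3, 'new': 1}
```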

### flatten_dict

Flatten a nested dictionary into dotted keys.

```python
flatten_dict({"a": {"b": 1}})  # {"a.b": 1}
```

### unflatten_dict

Reverse of `flatten_dict`.

```python
unflatten_dict({"a.b": 1})  # {"a": {"b": 1}}
```
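
For reference, minimal recursive implementations matching the documented behavior. These sketches are illustrative; the shipped versions may handle additional cases (e.g. key collisions or non-string keys):

```python
def sketch_flatten(d, prefix=""):
    """Flatten nested dicts into dotted keys: {'a': {'b': 1}} -> {'a.b': 1}."""
    out = {}
    for key, value in d.items():
        full = f"{prefix}{key}"
        if isinstance(value, dict):
            out.update(sketch_flatten(value, prefix=f"{full}."))
        else:
            out[full] = value
    return out

def sketch_unflatten(d):
    """Inverse of sketch_flatten: {'a.b': 1} -> {'a': {'b': 1}}."""
    out = {}
    for key, value in d.items():
        *parents, leaf = key.split(".")
        node = out
        for part in parents:
            node = node.setdefault(part, {})  # descend, creating dicts as needed
        node[leaf] = value
    return out

nested = {"a": {"b": 1, "c": {"d": 2}}, "e": 3}
flat = sketch_flatten(nested)
print(flat)                              # {'a.b': 1, 'a.c.d': 2, 'e': 3}
print(sketch_unflatten(flat) == nested)  # True
```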