Skip to content

trie

octonion.trie

Self-organizing octonionic trie.

A hierarchical memory structure where routing uses Fano plane subalgebra decomposition, growth is triggered by associator incompatibility, and updates use octonionic composition. No gradient computation required.

Each node stores
  • routing_key: fixed at creation, determines how inputs route through this node
  • content: accumulated via composition, represents the node's knowledge
  • category_counts: tracks which categories have been routed here (for evaluation)
Example

from octonion.trie import OctonionTrie

trie = OctonionTrie(associator_threshold=0.3)
trie.insert(some_octonion, category=0)
leaf = trie.query(some_octonion)
leaf.dominant_category  # -> 0

TrieNode dataclass

A node in the octonionic trie.

Source code in src/octonion/trie.py
@dataclass
class TrieNode:
    """One node of the octonionic trie.

    Attributes:
        routing_key: Fixed at creation; determines how inputs route
            through this node.
        content: The node's accumulated knowledge, updated via composition.
        children: Child nodes keyed by subalgebra index.
        subalgebra_idx: Index of the subalgebra this node was created
            under, if any.
        insert_count: Number of insertions seen at this node.
        category_counts: Per-category insertion tallies (evaluation only).
        depth: Distance of this node from the root.
        buffer: Ring buffer holding the 30 most recent entries.
        _policy_state: Scratch space owned by the active ThresholdPolicy.
    """

    routing_key: torch.Tensor
    content: torch.Tensor
    children: dict[int, TrieNode] = field(default_factory=dict)
    subalgebra_idx: int | None = None
    insert_count: int = 0
    category_counts: dict[int, int] = field(default_factory=dict)
    depth: int = 0
    buffer: deque = field(default_factory=lambda: deque(maxlen=30))
    _policy_state: dict = field(default_factory=dict)

    @property
    def dominant_category(self) -> int | None:
        """Category with the most inserts at this node, or None if empty."""
        if not self.category_counts:
            return None
        best_cat, _ = max(self.category_counts.items(), key=lambda kv: kv[1])
        return best_cat

    @property
    def is_leaf(self) -> bool:
        """True when this node has no children."""
        return not self.children

dominant_category property

Category with the most inserts at this node.

ThresholdPolicy

Bases: ABC

Abstract base class for trie threshold strategies.

The OctonionTrie delegates all threshold decisions to a ThresholdPolicy. This decouples the trie's self-organization logic from how thresholds are determined, enabling pluggable adaptation strategies.

Source code in src/octonion/trie.py
class ThresholdPolicy(ABC):
    """Interface for pluggable trie threshold strategies.

    OctonionTrie never hardcodes its thresholds: every threshold decision
    is delegated to a ThresholdPolicy instance. Keeping the decision
    behind this interface separates the trie's self-organization
    mechanics from the strategy used to tune them.
    """

    @abstractmethod
    def get_assoc_threshold(
        self, node: TrieNode, depth: int, parent: TrieNode | None = None
    ) -> float:
        """Associator-norm threshold used when routing at this node.

        Args:
            node: Candidate child whose compatibility is being tested.
            depth: Depth of the parent node.
            parent: The parent node; None only for root-level queries.
                Policies that derive associators from buffered inputs need
                the parent to build the proper triple [buf, child, parent].
        """
        ...

    @abstractmethod
    def get_sim_threshold(self, node: TrieNode, depth: int) -> float:
        """Similarity threshold used for rumination at this node."""
        ...

    @abstractmethod
    def get_consolidation_params(self, node: TrieNode, depth: int) -> tuple[float, int]:
        """(min_share, min_count) pair governing consolidation at this node."""
        ...

    def on_insert(self, node: TrieNode, x: torch.Tensor, assoc_norm: float) -> None:
        """Hook invoked after every insertion; default is a no-op."""
        pass

get_assoc_threshold(node, depth, parent=None) abstractmethod

Return the associator norm threshold for routing at this node.

Parameters:

- node (TrieNode, required): The child node being evaluated for compatibility.
- depth (int, required): Depth of the parent node.
- parent (TrieNode | None, default None): The parent node (None only for root-level queries). Policies that need to compute associators from buffer entries require the parent to form the correct triple [buf, child, parent].
Source code in src/octonion/trie.py
@abstractmethod
def get_assoc_threshold(
    self, node: TrieNode, depth: int, parent: TrieNode | None = None
) -> float:
    """Return the associator norm threshold for routing at this node.

    Args:
        node: The child node being evaluated for compatibility.
        depth: Depth of the parent node.
        parent: The parent node (None only for root-level queries).
            Policies that need to compute associators from buffer entries
            require the parent to form the correct triple [buf, child, parent].

    Returns:
        The associator-norm threshold to apply at this node.
    """
    ...

get_sim_threshold(node, depth) abstractmethod

Return the similarity threshold for rumination at this node.

Source code in src/octonion/trie.py
@abstractmethod
def get_sim_threshold(self, node: TrieNode, depth: int) -> float:
    """Return the similarity threshold for rumination at this node.

    Args:
        node: The node being evaluated.
        depth: Depth of the node in the trie.
    """
    ...

get_consolidation_params(node, depth) abstractmethod

Return (min_share, min_count) for consolidation at this node.

Source code in src/octonion/trie.py
@abstractmethod
def get_consolidation_params(self, node: TrieNode, depth: int) -> tuple[float, int]:
    """Return (min_share, min_count) for consolidation at this node.

    Args:
        node: The node being evaluated.
        depth: Depth of the node in the trie.
    """
    ...

on_insert(node, x, assoc_norm)

Optional hook called after each insertion for policy updates.

Source code in src/octonion/trie.py
def on_insert(self, node: TrieNode, x: torch.Tensor, assoc_norm: float) -> None:
    """Optional hook called after each insertion for policy updates.

    The default implementation is a no-op; stateful policies override it
    to accumulate per-node statistics.

    Args:
        node: The node that received the insertion.
        x: The inserted input tensor.
        assoc_norm: Associator norm observed for this insertion.
    """
    pass

GlobalPolicy

Bases: ThresholdPolicy

Global (fixed) threshold policy -- reproduces original hardcoded behavior.

All nodes at all depths use the same thresholds. This is the baseline policy and the default when no explicit policy is provided.

Source code in src/octonion/trie.py
class GlobalPolicy(ThresholdPolicy):
    """Fixed thresholds shared by every node at every depth.

    Reproduces the original hardcoded behavior and is the default
    baseline when no explicit policy is supplied.
    """

    def __init__(
        self,
        assoc_threshold: float = 0.3,
        sim_threshold: float = 0.1,
        min_share: float = 0.05,
        min_count: int = 3,
    ):
        # All four values are constants for the lifetime of the policy.
        self.assoc_threshold = assoc_threshold
        self.sim_threshold = sim_threshold
        self.min_share = min_share
        self.min_count = min_count

    def get_assoc_threshold(
        self, node: TrieNode, depth: int, parent: TrieNode | None = None
    ) -> float:
        """Same associator threshold regardless of node or depth."""
        return self.assoc_threshold

    def get_sim_threshold(self, node: TrieNode, depth: int) -> float:
        """Same similarity threshold regardless of node or depth."""
        return self.sim_threshold

    def get_consolidation_params(self, node: TrieNode, depth: int) -> tuple[float, int]:
        """Constant (min_share, min_count) pair."""
        return self.min_share, self.min_count

PerNodeEMAPolicy

Bases: ThresholdPolicy

Per-node EMA of observed associator norms.

Each node maintains an exponential moving average of associator norms seen during insertion. The threshold adapts to mean + k * std of the local distribution. Falls back to base threshold until the node has accumulated enough observations (min_obs).

Per-node state keys: ema_mean, ema_var, ema_count

Source code in src/octonion/trie.py
class PerNodeEMAPolicy(ThresholdPolicy):
    """Adaptive thresholds from an EMA of per-node associator norms.

    Every node tracks an exponentially weighted mean and variance of the
    associator norms observed at insertion time. Once at least ``min_obs``
    samples have been seen, the routing threshold becomes
    ``mean + k * std``; before that the fixed ``base_assoc`` is used.

    Per-node state keys: ema_mean, ema_var, ema_count
    """

    def __init__(
        self,
        alpha: float = 0.1,
        k: float = 1.5,
        base_assoc: float = 0.3,
        sim_threshold: float = 0.1,
        min_share: float = 0.05,
        min_count: int = 3,
        min_obs: int = 3,
    ):
        self.alpha = alpha
        self.k = k
        self.base_assoc = base_assoc
        self.sim_threshold = sim_threshold
        self.min_share = min_share
        self.min_count = min_count
        self.min_obs = min_obs

    def get_assoc_threshold(
        self, node: TrieNode, depth: int, parent: TrieNode | None = None
    ) -> float:
        """mean + k*std from the node's EMA stats, or base_assoc if too few obs."""
        state = node._policy_state
        if state.get("ema_count", 0) < self.min_obs:
            return self.base_assoc
        spread = math.sqrt(max(state["ema_var"], 0.0))
        return state["ema_mean"] + self.k * spread

    def get_sim_threshold(self, node: TrieNode, depth: int) -> float:
        """Fixed similarity threshold (not adapted by this policy)."""
        return self.sim_threshold

    def get_consolidation_params(self, node: TrieNode, depth: int) -> tuple[float, int]:
        """Fixed consolidation parameters (not adapted by this policy)."""
        return self.min_share, self.min_count

    def on_insert(self, node: TrieNode, x: torch.Tensor, assoc_norm: float) -> None:
        """Fold the observed associator norm into the node's EMA stats."""
        state = node._policy_state
        n_seen = state.get("ema_count", 0)
        if n_seen == 0:
            # First observation seeds the mean; variance starts at zero.
            state["ema_mean"] = assoc_norm
            state["ema_var"] = 0.0
            state["ema_count"] = 1
            return
        prev_mean = state["ema_mean"]
        deviation = assoc_norm - prev_mean
        # EMA updates; the variance term uses the deviation from the
        # *previous* mean, matching the standard EW-variance recursion.
        state["ema_mean"] = (1 - self.alpha) * prev_mean + self.alpha * assoc_norm
        state["ema_var"] = (1 - self.alpha) * state["ema_var"] + self.alpha * deviation * deviation
        state["ema_count"] = n_seen + 1

PerNodeMeanStdPolicy

Bases: ThresholdPolicy

Per-node running mean + std using Welford's online algorithm.

Like PerNodeEMAPolicy but uses unweighted running statistics -- all observations contribute equally regardless of order. The threshold adapts to mean + k * std after sufficient observations.

Per-node state keys: welford_mean, welford_M2, welford_count

Source code in src/octonion/trie.py
class PerNodeMeanStdPolicy(ThresholdPolicy):
    """Adaptive thresholds from exact per-node running statistics.

    Counterpart to PerNodeEMAPolicy that weights every observation
    equally: mean and variance are maintained incrementally with
    Welford's online algorithm. After ``min_obs`` samples the routing
    threshold is ``mean + k * std``; otherwise ``base_assoc``.

    Per-node state keys: welford_mean, welford_M2, welford_count
    """

    def __init__(
        self,
        k: float = 1.5,
        base_assoc: float = 0.3,
        sim_threshold: float = 0.1,
        min_share: float = 0.05,
        min_count: int = 3,
        min_obs: int = 3,
    ):
        self.k = k
        self.base_assoc = base_assoc
        self.sim_threshold = sim_threshold
        self.min_share = min_share
        self.min_count = min_count
        self.min_obs = min_obs

    def get_assoc_threshold(
        self, node: TrieNode, depth: int, parent: TrieNode | None = None
    ) -> float:
        """mean + k*std from Welford stats, or base_assoc until min_obs."""
        state = node._policy_state
        n_seen = state.get("welford_count", 0)
        if n_seen < self.min_obs:
            return self.base_assoc
        # Population variance: M2 / n.
        spread = math.sqrt(max(state["welford_M2"] / n_seen, 0.0))
        return state["welford_mean"] + self.k * spread

    def get_sim_threshold(self, node: TrieNode, depth: int) -> float:
        """Fixed similarity threshold (not adapted by this policy)."""
        return self.sim_threshold

    def get_consolidation_params(self, node: TrieNode, depth: int) -> tuple[float, int]:
        """Fixed consolidation parameters (not adapted by this policy)."""
        return self.min_share, self.min_count

    def on_insert(self, node: TrieNode, x: torch.Tensor, assoc_norm: float) -> None:
        """Standard Welford update of the node's running mean and M2."""
        state = node._policy_state
        n_seen = state.get("welford_count", 0)
        if n_seen == 0:
            state["welford_mean"] = assoc_norm
            state["welford_M2"] = 0.0
            state["welford_count"] = 1
            return
        n_seen += 1
        delta = assoc_norm - state["welford_mean"]
        state["welford_mean"] += delta / n_seen
        # M2 accumulates delta * (x - new_mean); M2 / n is the variance.
        state["welford_M2"] += delta * (assoc_norm - state["welford_mean"])
        state["welford_count"] = n_seen

DepthPolicy

Bases: ThresholdPolicy

Depth-dependent threshold: threshold = base * decay_factor ^ depth.

decay_factor < 1: thresholds tighten with depth (deeper = stricter). decay_factor > 1: thresholds loosen with depth (deeper = more tolerant). decay_factor = 1: equivalent to GlobalPolicy.

Source code in src/octonion/trie.py
class DepthPolicy(ThresholdPolicy):
    """Geometric depth schedule: threshold = base * decay_factor ** depth.

    decay_factor < 1 tightens thresholds as depth grows (stricter),
    decay_factor > 1 loosens them (more tolerant), and
    decay_factor = 1 degenerates to GlobalPolicy.
    """

    def __init__(
        self,
        base_assoc: float = 0.3,
        decay_factor: float = 1.0,
        sim_threshold: float = 0.1,
        min_share: float = 0.05,
        min_count: int = 3,
    ):
        self.base_assoc = base_assoc
        self.decay_factor = decay_factor
        self.sim_threshold = sim_threshold
        self.min_share = min_share
        self.min_count = min_count

    def get_assoc_threshold(
        self, node: TrieNode, depth: int, parent: TrieNode | None = None
    ) -> float:
        """Associator threshold scaled geometrically by depth."""
        return self.base_assoc * self.decay_factor ** depth

    def get_sim_threshold(self, node: TrieNode, depth: int) -> float:
        """Fixed similarity threshold (depth-independent)."""
        return self.sim_threshold

    def get_consolidation_params(self, node: TrieNode, depth: int) -> tuple[float, int]:
        """Fixed consolidation parameters (depth-independent)."""
        return self.min_share, self.min_count

AlgebraicPurityPolicy

Bases: ThresholdPolicy

Threshold based on algebraic purity of the node's buffer.

Uses two independent signals from the node's buffer: (a) Variance of associator norms between buffer entries and the routing key. (b) Variance of inner products between buffer entries and the routing key.

Low variance = high algebraic purity = can tighten threshold. High variance = heterogeneous content = should loosen threshold.

threshold = base * (1 + sensitivity * combined_signal) combined_signal = assoc_weight * norm_variance + sim_weight * sim_variance

Source code in src/octonion/trie.py
class AlgebraicPurityPolicy(ThresholdPolicy):
    """Scale routing thresholds by the algebraic purity of the node's buffer.

    Two independent buffer-derived signals are combined:
    (a) variance of associator norms of [buf, child, parent] triples, and
    (b) variance of inner products between buffer entries and routing_key.

    A homogeneous (algebraically pure) buffer has low variance, keeping
    the threshold near base; a heterogeneous buffer inflates it:

        threshold = base * (1 + sensitivity * combined_signal)
        combined_signal = assoc_weight * norm_variance + sim_weight * sim_variance
    """

    def __init__(
        self,
        base_assoc: float = 0.3,
        assoc_weight: float = 0.5,
        sim_weight: float = 0.5,
        sensitivity: float = 2.0,
        sim_threshold: float = 0.1,
        min_share: float = 0.05,
        min_count: int = 3,
    ):
        self.base_assoc = base_assoc
        self.assoc_weight = assoc_weight
        self.sim_weight = sim_weight
        self.sensitivity = sensitivity
        self.sim_threshold = sim_threshold
        self.min_share = min_share
        self.min_count = min_count

    @staticmethod
    def _variance(values: list[float]) -> float:
        """Population variance of values; 0.0 for fewer than two samples."""
        if len(values) < 2:
            return 0.0
        mu = sum(values) / len(values)
        return sum((v - mu) ** 2 for v in values) / len(values)

    def get_assoc_threshold(
        self, node: TrieNode, depth: int, parent: TrieNode | None = None
    ) -> float:
        """Purity-scaled threshold; base_assoc when signals are unavailable."""
        # Both a populated buffer and a parent are required: the routing
        # triple is [input, child, parent], so the analogous measurement
        # for buffer entries is [buf, child, parent]. Root-level queries
        # have no parent, so they fall back to the base threshold.
        if len(node.buffer) < 3 or parent is None:
            return self.base_assoc

        child_oct = Octonion(node.routing_key)
        parent_oct = Octonion(parent.routing_key)
        assoc_norms: list[float] = []
        sim_values: list[float] = []
        for buf_x, _ in node.buffer:
            triple = associator(Octonion(buf_x), child_oct, parent_oct)
            assoc_norms.append(triple.components.norm().item())
            sim_values.append(torch.dot(buf_x, node.routing_key).item())

        combined = (
            self.assoc_weight * self._variance(assoc_norms)
            + self.sim_weight * self._variance(sim_values)
        )
        return self.base_assoc * (1.0 + self.sensitivity * combined)

    def get_sim_threshold(self, node: TrieNode, depth: int) -> float:
        """Fixed similarity threshold (not adapted by this policy)."""
        return self.sim_threshold

    def get_consolidation_params(self, node: TrieNode, depth: int) -> tuple[float, int]:
        """Fixed consolidation parameters (not adapted by this policy)."""
        return self.min_share, self.min_count

MetaTriePolicy

Bases: ThresholdPolicy

Meta-trie optimizer: a second OctonionTrie adapts classifier thresholds.

Per D-12: Uses the same OctonionTrie class (not a subclass). Per D-13: Categories are discretized threshold actions (compounding). Per D-14: Two input encoding modes. Per D-02: All feedback is unsupervised — ratio of mean_norm/threshold.

Design: Thresholds compound multiplicatively. Each action multiplies the node's current threshold by a factor. Over successive actions, thresholds walk toward the optimal value discovered by the ratio feedback signal (mean_assoc_norm / current_threshold).

The ratio signal is self-referential: the trie measures how its own routing decisions relate to its own thresholds. No external labels.

Source code in src/octonion/trie.py
class MetaTriePolicy(ThresholdPolicy):
    """Meta-trie optimizer: a second OctonionTrie adapts classifier thresholds.

    Per D-12: Uses the same OctonionTrie class (not a subclass).
    Per D-13: Categories are discretized threshold actions (compounding).
    Per D-14: Two input encoding modes.
    Per D-02: All feedback is unsupervised — ratio of mean_norm/threshold.

    Design: Thresholds compound multiplicatively. Each action multiplies
    the node's current threshold by a factor. Over successive actions,
    thresholds walk toward the optimal value discovered by the ratio
    feedback signal (mean_assoc_norm / current_threshold).

    The ratio signal is self-referential: the trie measures how its own
    routing decisions relate to its own thresholds. No external labels.

    Feedback loop:
      1. ACT:     Query meta-trie → get recommended action
      2. APPLY:   Multiply node threshold by action factor
      3. OBSERVE: Collect assoc_norms, compute post-action ratio
      4. LEARN:   Did ratio move toward target? → insert outcome
    """

    # Compounding multiplicative factors per D-13
    ACTIONS = {
        0: 0.7,   # tighten fast
        1: 0.9,   # tighten slow
        2: 1.0,   # keep
        3: 1.1,   # loosen slow
        4: 1.4,   # loosen fast
    }
    # Maps each action to its inverse direction; used to label outcomes
    # when an action made the ratio worse (the opposite would have helped).
    _OPPOSITE = {0: 4, 1: 3, 2: 2, 3: 1, 4: 0}
    # Target ratio: mean_norm / threshold ≈ 0.5 means threshold is well-calibrated
    _TARGET_RATIO = 0.5

    def __init__(
        self,
        base_assoc: float = 0.3,
        sim_threshold: float = 0.1,
        min_share: float = 0.05,
        min_count: int = 3,
        signal_encoding: str = "algebraic",       # or "signal_vector" per D-14
        update_frequency: int = 10,               # per D-16: per-N-compatible-routings
        observation_window: int = 5,              # samples before evaluating
        exploration_rate: float = 0.2,            # initial epsilon
        exploration_decay: float = 0.995,         # per-update-event decay
        exploration_min: float = 0.01,            # floor
        generalize_every: int = 10,               # sweep every N updates
        generalize_fraction: float = 0.3,         # fraction per sweep
        self_referential: bool = False,            # per D-17
        meta_seed: int = 7919,
    ):
        self.base_assoc = base_assoc
        self.sim_threshold = sim_threshold
        self.min_share = min_share
        self.min_count = min_count
        self.signal_encoding = signal_encoding
        self.update_frequency = update_frequency
        self.observation_window = observation_window
        self._exploration_rate = exploration_rate
        self.exploration_decay = exploration_decay
        self.exploration_min = exploration_min
        self.generalize_every = generalize_every
        self.generalize_fraction = generalize_fraction
        self.self_referential = self_referential

        # The meta-trie itself is a plain OctonionTrie (per D-12); it learns
        # a mapping from encoded node state → recommended action category.
        self.meta_trie = OctonionTrie(
            associator_threshold=base_assoc,
            similarity_threshold=sim_threshold,
            seed=meta_seed,
        )

        self._insert_counter = 0
        self._update_counter = 0
        self._convergence_history: list[float] = []
        self._prev_thresholds: dict[int, float] = {}
        self._rng = torch.Generator().manual_seed(meta_seed + 1)
        # Weak registry of nodes this policy has touched: lets sweeps reach
        # all known nodes without keeping dead ones alive.
        self._id_to_node: weakref.WeakValueDictionary[int, TrieNode] = (
            weakref.WeakValueDictionary()
        )
        self._meta_outcomes: list[bool] = []

    @property
    def exploration_rate(self) -> float:
        """Current epsilon for epsilon-greedy action selection."""
        return self._exploration_rate

    def get_assoc_threshold(
        self, node: TrieNode, depth: int, parent: TrieNode | None = None
    ) -> float:
        # Each node carries its own compounded threshold; new nodes start
        # at the base value until the meta-trie acts on them.
        return node._policy_state.get("meta_threshold", self.base_assoc)

    def get_sim_threshold(self, node: TrieNode, depth: int) -> float:
        # Similarity threshold is not adapted by this policy.
        return self.sim_threshold

    def get_consolidation_params(self, node: TrieNode, depth: int) -> tuple[float, int]:
        # Consolidation parameters are not adapted by this policy.
        return self.min_share, self.min_count

    def on_insert(self, node: TrieNode, x: torch.Tensor, assoc_norm: float) -> None:
        """OBSERVE step plus the scheduler for ACT/LEARN/housekeeping."""
        state = node._policy_state
        self._id_to_node[id(node)] = node

        # Record the observation; cap history at the 200 most recent norms.
        state.setdefault("meta_obs_norms", []).append(assoc_norm)
        if len(state["meta_obs_norms"]) > 200:
            state["meta_obs_norms"] = state["meta_obs_norms"][-200:]
        # While an action is pending, also collect post-action norms.
        if "meta_action_taken" in state:
            state.setdefault("meta_post_norms", []).append(assoc_norm)

        # Enough post-action samples → evaluate the pending action now.
        post_norms = state.get("meta_post_norms", [])
        if len(post_norms) >= self.observation_window and "meta_action_taken" in state:
            self._evaluate_and_learn(node)

        self._insert_counter += 1
        if self._insert_counter % self.update_frequency == 0:
            # Flush any still-pending action before taking a new one.
            if "meta_action_taken" in state and state.get("meta_post_norms"):
                self._evaluate_and_learn(node)
            self._act_on_node(node)
            self._update_counter += 1
            if self._update_counter % self.generalize_every == 0:
                self._generalize_sweep()
            self._track_convergence()
            if self.self_referential:
                self._adapt_meta_trie_threshold()
            # Decay epsilon toward its floor on every update event.
            self._exploration_rate = max(
                self.exploration_min,
                self._exploration_rate * self.exploration_decay,
            )
            if self._update_counter % (self.generalize_every * 5) == 0:
                self._prune_stale()

    def _encode(self, node: TrieNode) -> torch.Tensor:
        """Encode node state as octonion for meta-trie input.

        "algebraic" mode uses the node's routing key directly; otherwise an
        8-dim vector of normalized health signals is built (per D-14).
        """
        if self.signal_encoding == "algebraic":
            return node.routing_key.clone()
        state = node._policy_state
        norms = state.get("meta_obs_norms", [0.0])
        norms_t = torch.tensor(norms[-30:], dtype=torch.float64)
        thresh = state.get("meta_threshold", self.base_assoc)
        mean_norm = norms_t.mean().item()
        ratio = mean_norm / max(thresh, 1e-10)
        # Each component is clamped to keep the encoding in a bounded range.
        return torch.tensor([
            min(ratio, 3.0),                     # norm/threshold ratio (key signal)
            min(norms_t.std().item() / max(mean_norm, 1e-10), 2.0) if len(norms) > 1 else 0.0,
            len(node.children) / 7.0,
            min(node.insert_count / 100.0, 2.0),
            min(thresh / self.base_assoc, 3.0),  # current threshold relative to base
            node.depth / 15.0,
            self._buffer_consistency(node),
            min(len(norms) / 30.0, 2.0),
        ], dtype=torch.float64)

    def _buffer_consistency(self, node: TrieNode) -> float:
        """Mean pairwise dot product over (at most) the first 5 buffer entries."""
        if len(node.buffer) < 2:
            return 0.0
        items = list(node.buffer)
        sims = []
        for i in range(min(5, len(items))):
            for j in range(i + 1, min(5, len(items))):
                sims.append(torch.dot(items[i][0], items[j][0]).item())
        return sum(sims) / len(sims) if sims else 0.0

    def _act_on_node(self, node: TrieNode) -> None:
        """ACT: query meta-trie, multiply node's threshold by recommended factor."""
        state = node._policy_state
        thresh = state.get("meta_threshold", self.base_assoc)

        # Record pre-action ratio for evaluation
        norms = state.get("meta_obs_norms", [])
        window = min(self.observation_window, len(norms))
        if window >= 3:
            pre_mean = sum(norms[-window:]) / window
            state["meta_pre_ratio"] = pre_mean / max(thresh, 1e-10)
        else:
            # Too few observations: assume a neutral ratio.
            state["meta_pre_ratio"] = 1.0

        # Query meta-trie
        meta_input = self._encode(node)
        leaf = self.meta_trie.query(meta_input)
        recommended = leaf.dominant_category

        # Epsilon-greedy: explore with a random action, else exploit the
        # meta-trie's recommendation (which may be None for a fresh leaf).
        if (recommended is None
                or torch.rand(1, generator=self._rng).item() < self._exploration_rate):
            action = torch.randint(0, 5, (1,), generator=self._rng).item()
        else:
            action = recommended

        # Compound: multiply current threshold by action factor
        factor = self.ACTIONS[action]
        # Clamp to [0.001, 5.0] so compounding can never run away.
        new_thresh = max(0.001, min(thresh * factor, 5.0))
        state["meta_threshold"] = new_thresh
        state["meta_action_taken"] = action
        state["meta_state_before"] = meta_input.clone()
        state["meta_post_norms"] = []

    def _evaluate_and_learn(self, node: TrieNode) -> None:
        """LEARN: did the ratio move toward the target (0.5)?

        ratio = mean_assoc_norm / threshold
        - ratio ≈ 0.5: threshold is well-calibrated
        - ratio << 0.5: threshold too loose (tighten helped if ratio increased)
        - ratio >> 0.5: threshold too tight (loosen helped if ratio decreased)

        This is purely self-referential: the trie evaluates its own threshold
        against its own routing behavior. No external labels.
        """
        state = node._policy_state
        action_taken = state.get("meta_action_taken")
        state_before = state.get("meta_state_before")
        pre_ratio = state.get("meta_pre_ratio", 1.0)

        # Incomplete pending record: discard rather than learn from it.
        if action_taken is None or state_before is None:
            self._clear_pending(state)
            return

        post_norms = state.get("meta_post_norms", [])
        if len(post_norms) < 3:
            self._clear_pending(state)
            return

        thresh = state.get("meta_threshold", self.base_assoc)
        post_mean = sum(post_norms) / len(post_norms)
        post_ratio = post_mean / max(thresh, 1e-10)

        # Did the ratio move closer to target?
        pre_dist = abs(pre_ratio - self._TARGET_RATIO)
        post_dist = abs(post_ratio - self._TARGET_RATIO)
        helped = post_dist < pre_dist - 0.02  # improved by at least 0.02
        hurt = post_dist > pre_dist + 0.02

        # Label: reinforce the action if it helped, its opposite if it hurt,
        # and "keep" (action 2) when the change was within the dead zone.
        if helped:
            outcome_label = action_taken
        elif hurt:
            outcome_label = self._OPPOSITE[action_taken]
        else:
            outcome_label = 2

        self.meta_trie.insert(state_before, category=outcome_label)
        self._meta_outcomes.append(helped)
        if len(self._meta_outcomes) > 100:
            self._meta_outcomes = self._meta_outcomes[-100:]
        self._clear_pending(state)

    def _clear_pending(self, state: dict) -> None:
        """Drop all pending-action bookkeeping from a node's policy state."""
        state.pop("meta_action_taken", None)
        state.pop("meta_state_before", None)
        state.pop("meta_pre_ratio", None)
        state.pop("meta_post_norms", None)

    def _generalize_sweep(self) -> None:
        """Periodically resolve pending actions and act on a fraction of nodes."""
        # Pass 1: settle every node that still has an action pending.
        for node in list(self._id_to_node.values()):
            if node is None:
                continue
            state = node._policy_state
            if "meta_action_taken" not in state:
                continue
            if state.get("meta_post_norms"):
                self._evaluate_and_learn(node)
            else:
                self._clear_pending(state)

        # Pass 2: act on a random fraction of the now-idle nodes.
        eligible = [
            nid for nid, node in self._id_to_node.items()
            if node is not None and "meta_action_taken" not in node._policy_state
        ]
        n = max(1, int(len(eligible) * self.generalize_fraction))
        indices = torch.randperm(len(eligible), generator=self._rng)[:n]
        for idx in indices:
            node = self._id_to_node.get(eligible[idx.item()])
            if node is not None:
                self._act_on_node(node)

    def _adapt_meta_trie_threshold(self) -> None:
        """Per D-17: loosen/tighten the meta-trie's own threshold by hit rate.

        NOTE(review): this reads/writes ``meta_trie.assoc_threshold`` while
        the constructor passes ``associator_threshold=`` — confirm that
        OctonionTrie exposes the attribute under this shorter name.
        """
        if len(self._meta_outcomes) < 10:
            return
        hit_rate = sum(self._meta_outcomes[-50:]) / len(self._meta_outcomes[-50:])
        if hit_rate < 0.3:
            self.meta_trie.assoc_threshold = max(
                0.001, self.meta_trie.assoc_threshold * 1.05)
        elif hit_rate > 0.6:
            self.meta_trie.assoc_threshold = max(
                0.001, self.meta_trie.assoc_threshold * 0.95)

    def _track_convergence(self) -> None:
        """Append the mean absolute per-node threshold change since last check."""
        curr = {}
        for nid, node in self._id_to_node.items():
            if node is not None:
                curr[nid] = node._policy_state.get("meta_threshold", self.base_assoc)
        if self._prev_thresholds:
            # Union of ids so nodes that appeared or vanished still count,
            # with base_assoc as the implicit value on the missing side.
            changes = [
                abs(curr.get(k, self.base_assoc) - self._prev_thresholds.get(k, self.base_assoc))
                for k in set(curr) | set(self._prev_thresholds)
            ]
            if changes:
                self._convergence_history.append(sum(changes) / len(changes))
        self._prev_thresholds = curr

    def _prune_stale(self) -> None:
        """Drop threshold snapshots for nodes the weak registry no longer holds."""
        live = set(self._id_to_node.keys())
        for k in [k for k in self._prev_thresholds if k not in live]:
            del self._prev_thresholds[k]

    @property
    def converged(self) -> bool:
        """True once the latest mean threshold change has fallen below 0.001."""
        if len(self._convergence_history) < 3:
            return False
        return self._convergence_history[-1] < 0.001

HybridPolicy

Bases: ThresholdPolicy

Combines two ThresholdPolicy instances per D-09.

Combination modes: - "mean": average of both policies' thresholds - "min": minimum (more conservative / tighter) - "max": maximum (more permissive / looser) - "adaptive": use policy_a in early epochs, transition to policy_b

Source code in src/octonion/trie.py
class HybridPolicy(ThresholdPolicy):
    """Combines two ThresholdPolicy instances per D-09.

    Combination modes:
    - "mean": average of both policies' thresholds
    - "min": minimum (more conservative / tighter)
    - "max": maximum (more permissive / looser)
    - "adaptive": use policy_a in early epochs, transition to policy_b
    """

    def __init__(
        self,
        policy_a: ThresholdPolicy | None = None,
        policy_b: ThresholdPolicy | None = None,
        combination: str = "mean",
        transition_inserts: int = 0,  # for "adaptive" mode: switch after N inserts
    ):
        self.policy_a = GlobalPolicy() if policy_a is None else policy_a
        self.policy_b = GlobalPolicy() if policy_b is None else policy_b
        self.combination = combination
        self.transition_inserts = transition_inserts
        self._total_inserts = 0

    def _combine(self, val_a: float, val_b: float) -> float:
        """Fold the two sub-policy values into one per the combination mode."""
        mode = self.combination
        if mode == "min":
            return min(val_a, val_b)
        if mode == "max":
            return max(val_a, val_b)
        if mode == "adaptive":
            # Blend linearly from policy_a toward policy_b as inserts accrue.
            if self.transition_inserts <= 0:
                return val_b
            alpha = min(1.0, self._total_inserts / self.transition_inserts)
            return (1 - alpha) * val_a + alpha * val_b
        # "mean" -- and the fallback for any unrecognized mode -- averages both.
        return (val_a + val_b) / 2.0

    def get_assoc_threshold(
        self, node: TrieNode, depth: int, parent: TrieNode | None = None
    ) -> float:
        """Combined associator threshold from both sub-policies."""
        thr_a = self.policy_a.get_assoc_threshold(node, depth, parent)
        thr_b = self.policy_b.get_assoc_threshold(node, depth, parent)
        return self._combine(thr_a, thr_b)

    def get_sim_threshold(self, node: TrieNode, depth: int) -> float:
        """Combined similarity threshold from both sub-policies."""
        thr_a = self.policy_a.get_sim_threshold(node, depth)
        thr_b = self.policy_b.get_sim_threshold(node, depth)
        return self._combine(thr_a, thr_b)

    def get_consolidation_params(self, node: TrieNode, depth: int) -> tuple[float, int]:
        """Combined (min_share, min_count); min_count truncated to int."""
        share_a, count_a = self.policy_a.get_consolidation_params(node, depth)
        share_b, count_b = self.policy_b.get_consolidation_params(node, depth)
        return self._combine(share_a, share_b), int(self._combine(count_a, count_b))

    def on_insert(self, node: TrieNode, x: torch.Tensor, assoc_norm: float) -> None:
        """Forward the observation to both sub-policies and advance the
        insert counter used by the "adaptive" blend."""
        self._total_inserts += 1
        self.policy_a.on_insert(node, x, assoc_norm)
        self.policy_b.on_insert(node, x, assoc_norm)

OctonionTrie

Self-organizing octonionic trie.

Parameters:

Name Type Description Default
associator_threshold float

Maximum associator norm for routing compatibility. Lower values produce more branching (finer discrimination).

0.3
similarity_threshold float

Minimum inner product for rumination acceptance.

0.1
max_depth int

Maximum trie depth.

15
seed int

Random seed for root routing key initialization.

0
dtype dtype

Tensor dtype (float64 recommended for algebraic precision).

float64
policy ThresholdPolicy | None

Pluggable ThresholdPolicy. If None, a GlobalPolicy is created from associator_threshold and similarity_threshold values.

None
Source code in src/octonion/trie.py
class OctonionTrie:
    """Self-organizing octonionic trie.

    Args:
        associator_threshold: Maximum associator norm for routing compatibility.
            Lower values produce more branching (finer discrimination).
        similarity_threshold: Minimum inner product for rumination acceptance.
        max_depth: Maximum trie depth.
        seed: Random seed for root routing key initialization.
        dtype: Tensor dtype (float64 recommended for algebraic precision).
        policy: Pluggable ThresholdPolicy. If None, a GlobalPolicy is created
            from associator_threshold and similarity_threshold values.
    """

    def __init__(
        self,
        associator_threshold: float = 0.3,
        similarity_threshold: float = 0.1,
        max_depth: int = 15,
        seed: int = 0,
        dtype: torch.dtype = torch.float64,
        policy: ThresholdPolicy | None = None,
    ):
        # Default threshold/policy reviewed in Phase T2 (adaptive thresholds) based
        # on cross-benchmark analysis. GlobalPolicy(assoc_threshold=0.3) remains the
        # default: the Phase T2 ThresholdPolicy abstraction added 8 strategy
        # implementations (EMA, MeanStd, Depth, AlgebraicPurity, MetaTrie, Hybrid),
        # but the global baseline with threshold 0.3 provides robust performance
        # across all 5 benchmarks (MNIST, Fashion-MNIST, CIFAR-10, Text 4/20-class).
        # Adaptive strategies are available via the `policy` parameter for tasks
        # where per-node or depth-dependent thresholds are beneficial.
        # See results/T2/analysis/statistical_report.json for full analysis.
        gen = torch.Generator().manual_seed(seed)
        root_key = torch.randn(8, dtype=dtype, generator=gen)
        root_key = root_key / root_key.norm()

        self.root = TrieNode(routing_key=root_key, content=root_key.clone())
        self.max_depth = max_depth
        self.dtype = dtype
        self.n_nodes = 1
        self.total_inserts = 0
        self.rumination_rejections = 0
        self.consolidation_merges = 0

        # Set up threshold policy
        if policy is not None:
            self.policy = policy
        else:
            self.policy = GlobalPolicy(
                assoc_threshold=associator_threshold,
                sim_threshold=similarity_threshold,
            )

    @property
    def assoc_threshold(self) -> float:
        """Backward-compatible property delegating to policy."""
        return self.policy.get_assoc_threshold(self.root, 0)

    @assoc_threshold.setter
    def assoc_threshold(self, value: float) -> None:
        """Backward-compatible setter -- only works with GlobalPolicy."""
        if isinstance(self.policy, GlobalPolicy):
            self.policy.assoc_threshold = value

    @property
    def sim_threshold(self) -> float:
        """Backward-compatible property delegating to policy."""
        return self.policy.get_sim_threshold(self.root, 0)

    @sim_threshold.setter
    def sim_threshold(self, value: float) -> None:
        """Backward-compatible setter -- only works with GlobalPolicy."""
        if isinstance(self.policy, GlobalPolicy):
            self.policy.sim_threshold = value

    def _ranked_subalgebras(self, node: TrieNode, x: torch.Tensor) -> list[int]:
        """Subalgebra indices ranked (descending) by activation of
        node.routing_key * x.

        Shared helper so routing (_find_best_child) and both branching paths
        in insert() rank candidate slots identically.
        """
        product = octonion_mul(
            node.routing_key.unsqueeze(0), x.unsqueeze(0)
        ).squeeze(0)
        activations = subalgebra_activation(product)
        return [idx.item() for idx in activations.argsort(descending=True)]

    def _find_best_child(
        self, node: TrieNode, x: torch.Tensor
    ) -> tuple[int, TrieNode | None, float]:
        """Find the best child for input x at this node.

        Among existing children, selects the one most similar to x
        (highest inner product with routing key), filtered by associator
        compatibility. If no compatible child exists, returns the best
        unoccupied subalgebra slot for new child creation.

        Returns:
            (subalgebra_idx, child_or_None, associator_norm)
        """
        x_oct = Octonion(x)
        node_oct = Octonion(node.routing_key)

        best_compatible: tuple[int, TrieNode, float, float] | None = None

        for sub_idx, child in node.children.items():
            sim = torch.dot(x, child.routing_key).item()
            child_oct = Octonion(child.routing_key)
            assoc = associator(x_oct, child_oct, node_oct)
            assoc_norm = assoc.components.norm().item()

            threshold = self.policy.get_assoc_threshold(child, node.depth, node)
            if assoc_norm < threshold:
                # Track the compatible child with the highest similarity.
                if best_compatible is None or sim > best_compatible[3]:
                    best_compatible = (sub_idx, child, assoc_norm, sim)

        if best_compatible is not None:
            return best_compatible[0], best_compatible[1], best_compatible[2]

        # No compatible child: find best unoccupied subalgebra
        for idx in self._ranked_subalgebras(node, x):
            if idx not in node.children:
                return idx, None, float("inf")

        # All 7 occupied, all incompatible: return most similar
        best_sim_idx = max(
            node.children.keys(),
            key=lambda k: torch.dot(x, node.children[k].routing_key).item(),
        )
        threshold = self.policy.get_assoc_threshold(
            node.children[best_sim_idx], node.depth, node
        )
        # threshold + 1 guarantees the caller sees this as incompatible.
        return best_sim_idx, node.children[best_sim_idx], threshold + 1

    def _ruminate(self, node: TrieNode, x: torch.Tensor) -> bool:
        """Geometric consistency check: is x similar to this node's history?"""
        if len(node.buffer) < 3:
            return True
        sim_thresh = self.policy.get_sim_threshold(node, node.depth)
        key_sim = torch.dot(x, node.routing_key).item()
        if key_sim < sim_thresh * 0.5:
            return False
        sims = [torch.dot(x, buf_x).item() for buf_x, _ in node.buffer]
        return sum(sims) / len(sims) > sim_thresh * 0.3

    def insert(self, x: torch.Tensor, category: int | None = None) -> TrieNode:
        """Insert an octonion into the trie, returning the destination node.

        Policy notification: on_insert is called ONLY on the compatible routing
        path — when a sample naturally routes to an existing node and passes
        rumination. This gives policies clean associator norm statistics from
        real routing events. New child creation, forced descent, and max-depth
        placement do NOT trigger on_insert because:
        - New children have assoc_norm=0 by construction (key=x, alternativity)
        - Forced descent/branching passes wrong assoc_norm (from a different child)
        - These events are structural, not informative for threshold adaptation
        Nodes use base_assoc until min_obs real observations accumulate.
        """
        x = x.to(self.dtype)
        norm = x.norm()
        if norm > 0:
            x = x / norm

        self.total_inserts += 1
        node = self.root
        self._count(node, category)

        for _ in range(self.max_depth):
            if not node.children:
                sub_idx, _, _ = self._find_best_child(node, x)
                return self._create_child(node, x, sub_idx, category)

            sub_idx, child, assoc_norm = self._find_best_child(node, x)

            if child is None:
                return self._create_child(node, x, sub_idx, category)

            threshold = self.policy.get_assoc_threshold(child, node.depth, node)
            if assoc_norm < threshold and self._ruminate(child, x):
                # Compatible and consistent with history: descend and compose.
                node = child
                self._count(node, category)
                self._compose(node, x)
                node.buffer.append((x.clone(), category))
                self.policy.on_insert(node, x, assoc_norm)
                continue

            if assoc_norm >= threshold:
                # Associator incompatibility: branch into the most activated
                # unoccupied subalgebra slot, if any. (This is not a
                # rumination rejection, so the counter is untouched here --
                # the previous `+= int(assoc_norm < threshold)` was provably
                # always zero under this guard and has been removed.)
                for alt_idx in self._ranked_subalgebras(node, x):
                    if alt_idx not in node.children:
                        return self._create_child(node, x, alt_idx, category)
                # All occupied: descend into best
                node = child
                self._count(node, category)
                continue
            else:
                # Rumination rejected
                self.rumination_rejections += 1
                # Branch into an unoccupied slot other than the rejected one.
                for alt_idx in self._ranked_subalgebras(node, x):
                    if alt_idx != sub_idx and alt_idx not in node.children:
                        return self._create_child(node, x, alt_idx, category)
                node = child
                self._count(node, category)
                continue

        # Max depth reached: absorb into the current node.
        self._compose(node, x)
        node.buffer.append((x.clone(), category))
        return node

    def query(self, x: torch.Tensor) -> TrieNode:
        """Route x through the trie without modification."""
        x = x.to(self.dtype)
        norm = x.norm()
        if norm > 0:
            x = x / norm

        node = self.root
        for _ in range(self.max_depth):
            if not node.children:
                return node
            _, child, _ = self._find_best_child(node, x)
            if child is None:
                return node
            node = child
        return node

    def consolidate(self) -> None:
        """Merge underused nodes into siblings."""
        self._consolidate_node(self.root)

    def stats(self) -> dict:
        """Compute trie statistics."""
        nodes: list[TrieNode] = []
        leaves: list[TrieNode] = []
        max_depth = 0

        def _walk(n: TrieNode) -> None:
            nonlocal max_depth
            nodes.append(n)
            max_depth = max(max_depth, n.depth)
            if n.is_leaf:
                leaves.append(n)
            for c in n.children.values():
                _walk(c)

        _walk(self.root)
        return {
            "n_nodes": len(nodes),
            "n_leaves": len(leaves),
            "max_depth": max_depth,
            "rumination_rejections": self.rumination_rejections,
            "consolidation_merges": self.consolidation_merges,
        }

    # -- Private helpers --------------------------------------------------

    def _count(self, node: TrieNode, category: int | None) -> None:
        """Bump the node's insert counter and its per-category tally."""
        node.insert_count += 1
        if category is not None:
            node.category_counts[category] = node.category_counts.get(category, 0) + 1

    def _create_child(
        self, parent: TrieNode, x: torch.Tensor, sub_idx: int, category: int | None
    ) -> TrieNode:
        """Create a new child under `parent` keyed by x in slot sub_idx."""
        child = TrieNode(
            routing_key=x.clone(),
            content=x.clone(),
            subalgebra_idx=sub_idx,
            depth=parent.depth + 1,
            buffer=deque(maxlen=30),
        )
        parent.children[sub_idx] = child
        self.n_nodes += 1
        self._count(child, category)
        child.buffer.append((x.clone(), category))
        return child

    def _compose(self, node: TrieNode, x: torch.Tensor) -> None:
        """Fold x into the node's content via octonion multiplication,
        renormalizing to keep content on the unit sphere."""
        node.content = octonion_mul(
            node.content.unsqueeze(0), x.unsqueeze(0)
        ).squeeze(0)
        norm = node.content.norm()
        if norm > 0:
            node.content = node.content / norm

    def _consolidate_node(self, node: TrieNode) -> None:
        """Post-order consolidation: merge low-traffic children into the
        busiest surviving sibling, transferring counts."""
        if not node.children:
            return
        for child in list(node.children.values()):
            self._consolidate_node(child)

        total = sum(c.insert_count for c in node.children.values())
        if total == 0 or len(node.children) < 2:
            return

        min_share, min_count = self.policy.get_consolidation_params(node, node.depth)
        to_remove = [
            idx
            for idx, child in node.children.items()
            if child.insert_count / max(total, 1) < min_share
            and child.insert_count < min_count
        ]
        # Keep at least one child after removal.
        if not to_remove or len(node.children) - len(to_remove) < 1:
            return

        surviving = {k: v for k, v in node.children.items() if k not in to_remove}
        absorber = surviving[max(surviving, key=lambda k: surviving[k].insert_count)]

        for idx in to_remove:
            removed = node.children.pop(idx)
            for cat, count in removed.category_counts.items():
                absorber.category_counts[cat] = absorber.category_counts.get(cat, 0) + count
            absorber.insert_count += removed.insert_count
            self.n_nodes -= 1
            self.consolidation_merges += 1

assoc_threshold property writable

Backward-compatible property delegating to policy.

sim_threshold property writable

Backward-compatible property delegating to policy.

insert(x, category=None)

Insert an octonion into the trie, returning the destination node.

Policy notification: on_insert is called ONLY on the compatible routing path — when a sample naturally routes to an existing node and passes rumination. This gives policies clean associator norm statistics from real routing events. New child creation, forced descent, and max-depth placement do NOT trigger on_insert because: - New children have assoc_norm=0 by construction (key=x, alternativity) - Forced descent/branching passes wrong assoc_norm (from a different child) - These events are structural, not informative for threshold adaptation Nodes use base_assoc until min_obs real observations accumulate.

Source code in src/octonion/trie.py
def insert(self, x: torch.Tensor, category: int | None = None) -> TrieNode:
    """Insert an octonion into the trie, returning the destination node.

    Policy notification: on_insert is called ONLY on the compatible routing
    path — when a sample naturally routes to an existing node and passes
    rumination. This gives policies clean associator norm statistics from
    real routing events. New child creation, forced descent, and max-depth
    placement do NOT trigger on_insert because:
    - New children have assoc_norm=0 by construction (key=x, alternativity)
    - Forced descent/branching passes wrong assoc_norm (from a different child)
    - These events are structural, not informative for threshold adaptation
    Nodes use base_assoc until min_obs real observations accumulate.
    """
    # Normalize the input to the unit sphere (skip the zero vector).
    x = x.to(self.dtype)
    norm = x.norm()
    if norm > 0:
        x = x / norm

    self.total_inserts += 1
    node = self.root
    self._count(node, category)

    for _ in range(self.max_depth):
        # Leaf: create the first child in the best-ranked subalgebra slot.
        if not node.children:
            sub_idx, _, _ = self._find_best_child(node, x)
            return self._create_child(node, x, sub_idx, category)

        sub_idx, child, assoc_norm = self._find_best_child(node, x)

        # No compatible child but a free slot exists: branch there.
        if child is None:
            return self._create_child(node, x, sub_idx, category)

        threshold = self.policy.get_assoc_threshold(child, node.depth, node)
        # Compatible and consistent with the child's history: descend,
        # compose into its content, and notify the policy.
        if assoc_norm < threshold and self._ruminate(child, x):
            node = child
            self._count(node, category)
            self._compose(node, x)
            node.buffer.append((x.clone(), category))
            self.policy.on_insert(node, x, assoc_norm)
            continue

        # Associator incompatibility: try to branch into a free slot.
        if assoc_norm >= threshold:
            # NOTE(review): always adds 0 — the enclosing guard guarantees
            # assoc_norm >= threshold, so this predicate is false. Dead
            # statement; candidate for removal.
            self.rumination_rejections += int(assoc_norm < threshold)
            # Find unoccupied slot
            product = octonion_mul(
                node.routing_key.unsqueeze(0), x.unsqueeze(0)
            ).squeeze(0)
            activations = subalgebra_activation(product)
            for alt in activations.argsort(descending=True):
                alt_idx = alt.item()
                if alt_idx not in node.children:
                    return self._create_child(node, x, alt_idx, category)
            # All occupied: descend into best
            node = child
            self._count(node, category)
            continue
        else:
            # Rumination rejected
            self.rumination_rejections += 1
            product = octonion_mul(
                node.routing_key.unsqueeze(0), x.unsqueeze(0)
            ).squeeze(0)
            activations = subalgebra_activation(product)
            # Branch into a free slot other than the rejected child's.
            for alt in activations.argsort(descending=True):
                alt_idx = alt.item()
                if alt_idx != sub_idx and alt_idx not in node.children:
                    return self._create_child(node, x, alt_idx, category)
            node = child
            self._count(node, category)
            continue

    # Max depth reached: absorb into the current node.
    self._compose(node, x)
    node.buffer.append((x.clone(), category))
    return node

query(x)

Route x through the trie without modification.

Source code in src/octonion/trie.py
def query(self, x: torch.Tensor) -> TrieNode:
    """Route x through the trie without modification."""
    x = x.to(self.dtype)
    norm = x.norm()
    if norm > 0:
        x = x / norm

    node = self.root
    for _ in range(self.max_depth):
        if not node.children:
            return node
        _, child, _ = self._find_best_child(node, x)
        if child is None:
            return node
        node = child
    return node

consolidate()

Merge underused nodes into siblings.

Source code in src/octonion/trie.py
def consolidate(self) -> None:
    """Merge underused nodes into siblings.

    Delegates the post-order traversal to _consolidate_node, starting
    from the root.
    """
    self._consolidate_node(self.root)

stats()

Compute trie statistics.

Source code in src/octonion/trie.py
def stats(self) -> dict:
    """Compute trie statistics."""
    # Iterative depth-first walk over the whole trie.
    n_nodes = 0
    n_leaves = 0
    deepest = 0
    pending = [self.root]
    while pending:
        current = pending.pop()
        n_nodes += 1
        deepest = max(deepest, current.depth)
        if current.is_leaf:
            n_leaves += 1
        pending.extend(current.children.values())
    return {
        "n_nodes": n_nodes,
        "n_leaves": n_leaves,
        "max_depth": deepest,
        "rumination_rejections": self.rumination_rejections,
        "consolidation_merges": self.consolidation_merges,
    }

subalgebra_activation(x)

Compute activation strength for each of the 7 Fano plane subalgebras.

Each subalgebra is defined by a triple (e_i, e_j, e_k) of imaginary basis units. The activation is the norm of the projection onto those three components.

Parameters:

Name Type Description Default
x Tensor

Octonion tensor of shape [..., 8].

required

Returns:

Type Description
Tensor

Tensor of shape [..., 7] with activation norms.

Source code in src/octonion/trie.py
def subalgebra_activation(x: torch.Tensor) -> torch.Tensor:
    """Compute activation strength for each of the 7 Fano plane subalgebras.

    Each subalgebra is defined by a triple (e_i, e_j, e_k) of imaginary
    basis units. The activation is the norm of the projection onto those
    three components.

    Args:
        x: Octonion tensor of shape [..., 8].

    Returns:
        Tensor of shape [..., 7] with activation norms.
    """
    # One norm per Fano triple, preserving FANO_PLANE.triples ordering.
    norms = [
        torch.linalg.norm(
            torch.stack([x[..., i], x[..., j], x[..., k]], dim=-1), dim=-1
        )
        for (i, j, k) in FANO_PLANE.triples
    ]
    return torch.stack(norms, dim=-1)