/* ============================================================
   PART I — ANN
   ============================================================
   Renders sections 02–09 into the reader, then mounts React for
   the interactive bits (learning rate playground, backprop animation).
*/

(function() {
  const reader = document.getElementById('reader');
  const ANCHOR = document.getElementById('anchor-ann');

  const html = `
<!-- ========== 02 The artificial neuron ========== -->
<article id="ann-neuron" class="screen" data-screen-label="02 The artificial neuron">
  <div class="section-head">
    <div class="section-eyebrow">Part I · Section 02 · 4 min</div>
    <h2>The artificial neuron is just a tiny calculator.</h2>
    <p class="section-lede">Stripped of biological metaphor, a neuron is three steps: weight the inputs, add a bias, squash the result through a nonlinearity. That's it.</p>
  </div>

  <div class="prose">
    <p>The unit you'll see drawn as a circle in every diagram does this and only this:</p>

    <div class="eq">
      <span class="var">y</span> <span class="op">=</span> <span class="var">σ</span><span class="op">(</span>
      <span class="var">w</span><span class="sub">1</span><span class="var">x</span><span class="sub">1</span> <span class="op">+</span>
      <span class="var">w</span><span class="sub">2</span><span class="var">x</span><span class="sub">2</span> <span class="op">+</span>
      <span class="op">…</span> <span class="op">+</span>
      <span class="var">w</span><span class="sub">n</span><span class="var">x</span><span class="sub">n</span> <span class="op">+</span>
      <span class="var">b</span><span class="op">)</span>
      <span class="lbl">one neuron, one number out</span>
    </div>

    <div class="fig">
      <div class="fig-title"><strong>Anatomy of a neuron</strong><span>inputs · weights · sum · activation · output</span></div>
      <div id="fig-neuron"></div>
    </div>
    <div class="caption">Inputs flow in from the left, each multiplied by its <span class="tok-data">weight</span>. The neuron sums them, adds a bias, then runs the result through an <span class="tok-act">activation function</span>.</div>

    <h3>Why the activation function matters</h3>
    <p>If you removed the squash — the σ — every layer would just be a matrix multiply. Stack a hundred of them and you still have a single (giant) linear function: a hyperplane. Linear functions can't learn an <em>XOR</em>, never mind a face. The nonlinearity is what gives the network its expressive power.</p>

    <div class="fig fig-soft">
      <div class="fig-title"><strong>Three activations you'll meet daily</strong><span>shape determines behavior</span></div>
      <div id="fig-activations"></div>
    </div>
    <div class="caption"><b>Sigmoid</b> squashes to (0,1) — historic, now mostly retired. <b>Tanh</b> squashes to (−1,1) — zero-centered, useful inside RNNs. <b>ReLU</b> is just <code>max(0,x)</code> — almost free to compute, the modern default.</div>

    <div class="callout">
      <div class="callout-title">Developer takeaway</div>
      <p>When in doubt, use <b>ReLU</b> for hidden layers. Use <b>sigmoid</b> on a single output for binary classification, <b>softmax</b> for multi-class, and <b>nothing</b> (linear) for regression. We'll see why later.</p>
    </div>
  </div>
</article>

<!-- ========== 03 Stacking into a network ========== -->
<article id="ann-network" class="screen" data-screen-label="03 Stacking into a network">
  <div class="section-head">
    <div class="section-eyebrow">Part I · Section 03 · 3 min</div>
    <h2>Wire a few thousand of them up. That's a network.</h2>
    <p class="section-lede">A layer is a column of neurons that all see the same inputs. A network is layers stacked back-to-back. The output of layer L becomes the input of layer L+1.</p>
  </div>
  <div class="prose">
    <div class="fig fig-wide">
      <div class="fig-title"><strong>From neuron to network</strong><span>same operation, repeated</span></div>
      <div id="fig-network"></div>
    </div>
    <div class="caption">Every neuron in a "dense" layer is connected to every neuron in the next. We call this <b>fully-connected</b> or <b>dense</b>. Each connection has its own weight; each neuron has its own bias.</div>

    <p>For a network with input dimension <em>n</em> and a hidden layer of <em>m</em> neurons, the layer is just a matrix multiply:</p>
    <div class="eq">
      <span class="var">h</span> <span class="op">=</span> <span class="var">σ</span>(<span class="var">W</span><span class="var">x</span> <span class="op">+</span> <span class="var">b</span>)
      <span class="lbl">W is m×n, b is m, h is m. Linear algebra all the way down.</span>
    </div>

    <p>"Deep" learning just means: more than a couple of these. A modern image classifier might stack 50 or 150. A large language model: hundreds. The principle doesn't change.</p>

    <h3>Counting parameters</h3>
    <p>A dense layer from 784 inputs (a flattened 28×28 MNIST image) to 128 hidden units has <b>784 × 128 + 128 = 100,480 parameters</b>. Stack three such layers and you're already past a quarter-million weights — and we haven't even started training.</p>
  </div>
</article>

<!-- ========== 04 Forward pass animated ========== -->
<article id="ann-forward" class="screen" data-screen-label="04 Forward pass">
  <div class="section-head">
    <div class="section-eyebrow">Part I · Section 04 · 4 min</div>
    <h2>The forward pass: data flows left to right.</h2>
    <p class="section-lede">Press play. Watch a vector turn into a prediction, one layer at a time.</p>
  </div>
  <div class="prose">
    <div class="fig fig-wide" id="fig-forward-mount"></div>
    <div class="caption">The activation pulses are colored by sign and sized by magnitude. After ~6 layers, the input vector has been compressed and reshaped into a 3-class probability distribution.</div>

    <p>Three things to notice as you watch:</p>
    <ol>
      <li><b>Information bottlenecks.</b> Each layer can be wider or narrower than the last. Forcing data through a narrow layer makes the network find a compressed representation — this is how autoencoders work.</li>
      <li><b>The output is just another layer.</b> A softmax output is a dense layer with the softmax activation. Nothing special.</li>
      <li><b>It's all matrix math.</b> Every animation pulse is a multiply-accumulate. A modern GPU does billions per millisecond.</li>
    </ol>
  </div>
</article>

<!-- ========== 05 Loss ========== -->
<article id="ann-loss" class="screen" data-screen-label="05 Loss">
  <div class="section-head">
    <div class="section-eyebrow">Part I · Section 05 · 4 min</div>
    <h2>Loss measures how wrong we are.</h2>
    <p class="section-lede">Until we put a number on the network's mistakes, there's nothing for gradient descent to descend on.</p>
  </div>
  <div class="prose">
    <p>The <b>loss function</b> takes the network's prediction <em>ŷ</em> and the true label <em>y</em>, and returns a single scalar — bigger is worse. The two you'll meet 95% of the time:</p>

    <div class="split">
      <div>
        <h4>Mean Squared Error · regression</h4>
        <div class="eq" style="font-size:18px;text-align:left;border:none;padding:0;margin:8px 0">L = ½ · ( ŷ − y )<span class="sup">2</span></div>
        <p style="font-size:14px;color:var(--ink-soft);margin:0">Penalizes by squared distance. Predict house prices, temperatures, anything continuous.</p>
      </div>
      <div>
        <h4>Cross-entropy · classification</h4>
        <div class="eq" style="font-size:18px;text-align:left;border:none;padding:0;margin:8px 0">L = − Σ y<span class="sub">i</span> log ŷ<span class="sub">i</span></div>
        <p style="font-size:14px;color:var(--ink-soft);margin:0">Punishes confident-wrong predictions hard. Pair with softmax for multi-class.</p>
      </div>
    </div>

    <div class="fig fig-soft">
      <div class="fig-title"><strong>The loss landscape</strong><span>training is a hike downhill</span></div>
      <div id="fig-landscape"></div>
    </div>
    <div class="caption">Each axis is a weight. The surface is the loss for every weight combination. Training picks a starting point at random and rolls downhill — guided by the gradient.</div>

    <p class="aside">In a real network, the landscape lives in millions of dimensions, not two. We can't visualize it; we just trust the gradient and step.</p>
  </div>
</article>

<!-- ========== 06 Backpropagation ========== -->
<article id="ann-backprop" class="screen" data-screen-label="06 Backpropagation">
  <div class="section-head">
    <div class="section-eyebrow">Part I · Section 06 · 7 min</div>
    <h2>Backpropagation: assigning blame, layer by layer.</h2>
    <p class="section-lede">Every weight in the network deserves a share of credit (or blame) for the final error. Backprop is the bookkeeping.</p>
  </div>
  <div class="prose">
    <p>The intuition is older than computers: the chain rule. If <em>L</em> depends on <em>z</em> which depends on <em>w</em>, then</p>
    <div class="eq">∂L/∂w = (∂L/∂z) · (∂z/∂w)<span class="lbl">chain rule — calculus in two clauses</span></div>
    <p>Backprop is just the chain rule applied <b>repeatedly</b>, from the output back to the very first weight. The clever bit is that it reuses partial computations: the gradient flowing into layer <em>L</em> is exactly what you need to compute the gradient inside layer <em>L</em>.</p>

    <div class="fig fig-wide" id="fig-backprop-mount"></div>
    <div class="caption">Forward pass in <span class="tok-data">blue</span> carries activations left-to-right. Backward pass in <span class="tok-act">red</span> carries gradients right-to-left. <b>Each weight gets updated by its own contribution to the loss.</b></div>

    <h3>The training loop, in five lines</h3>
    <pre class="code"><span class="code-tag">pseudocode</span><span class="kw">for</span> epoch <span class="kw">in</span> <span class="fn">range</span>(<span class="num">N</span>):
    <span class="kw">for</span> x, y <span class="kw">in</span> dataset:
        y_hat = model(x)              <span class="com"># forward</span>
        loss  = loss_fn(y_hat, y)     <span class="com"># scalar</span>
        grads = loss.backward()       <span class="com"># backprop</span>
        weights -= lr * grads         <span class="com"># step</span></pre>

    <p>That's the entire algorithm. Every framework — TensorFlow, PyTorch, JAX — is fundamentally a high-performance implementation of these five lines.</p>

    <div class="callout">
      <div class="callout-title">Why &ldquo;automatic differentiation&rdquo; matters</div>
      <p>You don't write the gradient yourself. The framework records every operation in the forward pass as a graph, then walks it backward applying the chain rule mechanically. This is why building new architectures is mostly composing forward passes and letting autograd handle the rest.</p>
    </div>
  </div>
</article>

<!-- ========== 07 Learning rate playground ========== -->
<article id="ann-lr" class="screen" data-screen-label="07 Learning rate playground">
  <div class="section-head">
    <div class="section-eyebrow">Part I · Section 07 · 4 min · interactive</div>
    <h2>The learning rate is the most-tuned hyperparameter you'll ever touch.</h2>
    <p class="section-lede">Too small: training crawls. Too large: it explodes. Just right: a smooth descent.</p>
  </div>
  <div class="prose">
    <p>Drag the slider. Watch the loss curve and the path on the landscape change in real time.</p>

    <div class="fig fig-wide" id="fig-lr-mount"></div>
    <div class="caption">The same network, the same data, the same starting weights. Only the learning rate changes. Notice how a 100× difference in <em>lr</em> turns "learns in 30 steps" into "diverges in 3."</div>

    <h3>Field notes on learning rates</h3>
    <ul>
      <li><b>Start with 1e-3</b> for Adam. <b>1e-2</b> for SGD with momentum. These are reasonable defaults for 90% of problems.</li>
      <li><b>If loss explodes to NaN</b> in the first few steps — your lr is too high. Cut by 10×.</li>
      <li><b>If loss flatlines</b> for hundreds of steps — your lr is probably too low (or your data is broken).</li>
      <li><b>Use a scheduler.</b> Most modern training drops the lr by a factor at certain epochs, or follows a cosine curve. <code>tf.keras.callbacks.ReduceLROnPlateau</code> is a fine starting point.</li>
    </ul>
  </div>
</article>

<!-- ========== 08 Vanishing gradient ========== -->
<article id="ann-vanish" class="screen" data-screen-label="08 Vanishing gradient">
  <div class="section-head">
    <div class="section-eyebrow">Part I · Section 08 · 5 min</div>
    <h2>The vanishing gradient: why deep used to be impossible.</h2>
    <p class="section-lede">For decades, training networks deeper than a few layers simply didn't work. The reason was a small, sneaky multiplication.</p>
  </div>
  <div class="prose">
    <p>Backprop multiplies gradients through every layer. If each layer's local gradient is <em>less than 1</em> on average — and for sigmoid/tanh, it almost always is — then 20 layers later the signal has been multiplied by something like <code>0.25<span class="sup">20</span> ≈ 10<span class="sup">−12</span></code>. The first layers learn essentially nothing.</p>

    <div class="fig" id="fig-vanish-mount"></div>
    <div class="caption">Each bar is the gradient magnitude at one layer. With sigmoid activations, gradients <em>halve</em> roughly every layer; by layer 10 they're statistically zero.</div>

    <h3>What rescued deep learning</h3>
    <ol>
      <li><b>ReLU activations</b> — gradient is exactly 1 for positive inputs, no shrinkage.</li>
      <li><b>Better initialization</b> (He, Xavier) — keep activations from collapsing or exploding at layer 1.</li>
      <li><b>Batch / Layer normalization</b> — re-center activations between layers so the signal doesn't drift.</li>
      <li><b>Residual connections</b> — let gradients skip layers entirely. ResNet's central trick.</li>
    </ol>
    <p>Together, these turned "deep" from a research curiosity into engineering reality. Keep them in your back pocket — they're the answer when training silently fails to learn.</p>

    <div class="aside">The same issue, in <em>time</em> instead of <em>depth</em>, is the reason we'll need LSTMs in Part IV. Same disease, different vector.</div>
  </div>
</article>

<!-- ========== 09 Keras code ========== -->
<article id="ann-code" class="screen" data-screen-label="09 ANN in Keras">
  <div class="section-head">
    <div class="section-eyebrow">Part I · Section 09 · 3 min</div>
    <h2>An ANN in Keras, in 12 lines.</h2>
    <p class="section-lede">Everything we've covered, expressed as code. Skim now; come back when you're ready to run it.</p>
  </div>
  <div class="prose">
    <pre class="code"><span class="code-tag">tensorflow / keras</span><span class="kw">import</span> tensorflow <span class="kw">as</span> tf
<span class="kw">from</span> tensorflow.keras <span class="kw">import</span> layers, models

<span class="com"># 1. Define the architecture</span>
model = models.<span class="fn">Sequential</span>([
    layers.<span class="fn">Input</span>(shape=(<span class="num">784</span>,)),               <span class="com"># flat MNIST image</span>
    layers.<span class="fn">Dense</span>(<span class="num">128</span>, activation=<span class="str">'relu'</span>),
    layers.<span class="fn">Dense</span>(<span class="num">64</span>,  activation=<span class="str">'relu'</span>),
    layers.<span class="fn">Dense</span>(<span class="num">10</span>,  activation=<span class="str">'softmax'</span>)    <span class="com"># 10 digit classes</span>
])

<span class="com"># 2. Tell it what loss & optimizer to use</span>
model.<span class="fn">compile</span>(
    optimizer=tf.keras.optimizers.<span class="fn">Adam</span>(learning_rate=<span class="num">1e-3</span>),
    loss=<span class="str">'sparse_categorical_crossentropy'</span>,
    metrics=[<span class="str">'accuracy'</span>]
)

<span class="com"># 3. Train</span>
model.<span class="fn">fit</span>(x_train, y_train, epochs=<span class="num">10</span>, batch_size=<span class="num">64</span>,
          validation_data=(x_val, y_val))</pre>

    <p>Map this back to what we just learned:</p>
    <ul>
      <li><code>Sequential</code> ↔ stack of layers, output of one feeds the next.</li>
      <li><code>Dense(128, relu)</code> ↔ a fully-connected layer of 128 neurons with ReLU activation. (The <em>W</em> and <em>b</em> are created automatically.)</li>
      <li><code>Adam</code> ↔ a smarter SGD that adapts the learning rate per-weight.</li>
      <li><code>fit</code> ↔ the five-line training loop, hidden behind one call.</li>
    </ul>

    <div class="callout">
      <div class="callout-title">Practical sanity check</div>
      <p>On MNIST, this exact model reaches ~98% test accuracy in under a minute on a CPU. If yours doesn't — your data isn't normalized to <code>[0,1]</code>, or your labels are one-hot but you used <code>sparse_categorical_crossentropy</code> (or vice versa). It's almost always one of those two.</p>
    </div>
  </div>
</article>
  `;

  if (ANCHOR) {
    ANCHOR.insertAdjacentHTML('afterend', html);
  }

  // ============================================================
  // Static SVG figures (mounted as innerHTML)
  // ============================================================

  // ---- Neuron figure ----
  const figNeuron = document.getElementById('fig-neuron');
  if (figNeuron) {
    figNeuron.innerHTML = `
      <svg viewBox="0 0 720 280" width="100%">
        <!-- inputs -->
        ${[0,1,2,3].map(i => {
          const y = 60 + i * 50;
          return `
            <circle cx="80" cy="${y}" r="14" fill="#dde7f7" stroke="#1f6feb" stroke-width="1.4"/>
            <text x="80" y="${y+4}" text-anchor="middle" font-family="Fraunces" font-size="12" font-weight="600" fill="#1f6feb">x${'₁₂₃₄'[i]}</text>
            <line x1="94" y1="${y}" x2="320" y2="140" stroke="#1f6feb" stroke-width="1.2" opacity="0.5"/>
            <text x="200" y="${y - 24 + (i*4)}" font-family="JetBrains Mono" font-size="11" fill="#1f6feb" opacity="0.85">w${'₁₂₃₄'[i]}</text>
          `;
        }).join('')}
        <text x="40" y="40" font-family="JetBrains Mono" font-size="10" fill="#8a877f" letter-spacing="1">INPUTS</text>
        <text x="200" y="40" font-family="JetBrains Mono" font-size="10" fill="#8a877f" letter-spacing="1">WEIGHTS</text>

        <!-- summation -->
        <circle cx="340" cy="140" r="32" fill="#fbfaf6" stroke="#1a1a1a" stroke-width="1.5"/>
        <text x="340" y="146" text-anchor="middle" font-family="Fraunces" font-size="22" font-weight="500">Σ</text>
        <text x="340" y="190" text-anchor="middle" font-family="JetBrains Mono" font-size="10" fill="#8a877f" letter-spacing="1">SUM + BIAS</text>

        <!-- bias -->
        <circle cx="340" cy="60" r="12" fill="#f4f1e8" stroke="#8a877f" stroke-width="1"/>
        <text x="340" y="64" text-anchor="middle" font-family="Fraunces" font-size="11" font-weight="500">b</text>
        <line x1="340" y1="72" x2="340" y2="108" stroke="#8a877f" stroke-width="1.2" stroke-dasharray="2 2"/>

        <!-- activation -->
        <rect x="420" y="110" width="80" height="60" rx="3" fill="#f6e4d8" stroke="#c84e1d" stroke-width="1.5"/>
        <path d="M 432 154 Q 450 154 460 140 Q 470 126 488 126" stroke="#c84e1d" stroke-width="2" fill="none"/>
        <text x="460" y="190" text-anchor="middle" font-family="JetBrains Mono" font-size="10" fill="#c84e1d" letter-spacing="1">σ ACTIVATION</text>
        <line x1="372" y1="140" x2="420" y2="140" stroke="#1a1a1a" stroke-width="1.2"/>

        <!-- output -->
        <line x1="500" y1="140" x2="600" y2="140" stroke="#1a7a4c" stroke-width="1.5"/>
        <circle cx="616" cy="140" r="14" fill="#d6e8de" stroke="#1a7a4c" stroke-width="1.4"/>
        <text x="616" y="144" text-anchor="middle" font-family="Fraunces" font-size="12" font-weight="600" fill="#1a7a4c">y</text>
        <text x="616" y="180" text-anchor="middle" font-family="JetBrains Mono" font-size="10" fill="#1a7a4c" letter-spacing="1">OUTPUT</text>

        <!-- pulse -->
        <circle r="3" fill="#c84e1d">
          <animate attributeName="cx" values="80;340;460;616" dur="3s" repeatCount="indefinite"/>
          <animate attributeName="cy" values="110;140;140;140" dur="3s" repeatCount="indefinite"/>
          <animate attributeName="opacity" values="0;1;1;0" keyTimes="0;0.1;0.9;1" dur="3s" repeatCount="indefinite"/>
        </circle>
      </svg>
    `;
  }

  // ---- Activations figure ----
  const figAct = document.getElementById('fig-activations');
  if (figAct) {
    const W = 720, H = 220;
    function curve(fn) {
      const pts = [];
      for (let i = 0; i <= 60; i++) {
        const x = -6 + (i / 60) * 12;
        const y = fn(x);
        pts.push([x, y]);
      }
      return pts;
    }
    function plot(pts, x0, w, h) {
      const yPad = 10;
      const xs = pts.map(p => p[0]);
      const ys = pts.map(p => p[1]);
      const xmin = Math.min(...xs), xmax = Math.max(...xs);
      const ymin = -1.2, ymax = 2.2;
      const X = (x) => x0 + ((x - xmin) / (xmax - xmin)) * w;
      const Y = (y) => yPad + (1 - (y - ymin) / (ymax - ymin)) * (h - 2*yPad);
      let d = '';
      pts.forEach((p, i) => { d += (i ? ' L ' : 'M ') + X(p[0]).toFixed(1) + ' ' + Y(p[1]).toFixed(1); });
      return { path: d, X, Y, x0, w };
    }
    const sig = plot(curve(x => 1/(1+Math.exp(-x))), 30, 200, H);
    const tah = plot(curve(x => Math.tanh(x)), 270, 200, H);
    const rel = plot(curve(x => Math.max(0,x)), 510, 200, H);

    function axes(ax) {
      const yMid = ax.Y(0);
      const xMid = ax.X(0);
      return `
        <line x1="${ax.x0}" y1="${yMid}" x2="${ax.x0+ax.w}" y2="${yMid}" stroke="#c9c2ad" stroke-width="0.8"/>
        <line x1="${xMid}" y1="10" x2="${xMid}" y2="${H-10}" stroke="#c9c2ad" stroke-width="0.8"/>
      `;
    }
    figAct.innerHTML = `
      <svg viewBox="0 0 ${W} ${H+40}" width="100%">
        ${axes(sig)}
        <path d="${sig.path}" stroke="#1f6feb" stroke-width="2" fill="none"/>
        <text x="${sig.x0+ax_w(sig)/2}" y="${H+20}" text-anchor="middle" font-family="Fraunces" font-size="14" font-weight="600">sigmoid</text>
        <text x="${sig.x0+ax_w(sig)/2}" y="${H+34}" text-anchor="middle" font-family="JetBrains Mono" font-size="10" fill="#8a877f">σ(x) = 1 / (1 + e⁻ˣ)</text>

        ${axes(tah)}
        <path d="${tah.path}" stroke="#c84e1d" stroke-width="2" fill="none"/>
        <text x="${tah.x0+ax_w(tah)/2}" y="${H+20}" text-anchor="middle" font-family="Fraunces" font-size="14" font-weight="600">tanh</text>
        <text x="${tah.x0+ax_w(tah)/2}" y="${H+34}" text-anchor="middle" font-family="JetBrains Mono" font-size="10" fill="#8a877f">tanh(x)</text>

        ${axes(rel)}
        <path d="${rel.path}" stroke="#1a7a4c" stroke-width="2" fill="none"/>
        <text x="${rel.x0+ax_w(rel)/2}" y="${H+20}" text-anchor="middle" font-family="Fraunces" font-size="14" font-weight="600">ReLU</text>
        <text x="${rel.x0+ax_w(rel)/2}" y="${H+34}" text-anchor="middle" font-family="JetBrains Mono" font-size="10" fill="#8a877f">max(0, x)</text>
      </svg>
    `;
    function ax_w(ax){ return ax.w; }
  }

  // ---- Network figure ----
  const figNet = document.getElementById('fig-network');
  if (figNet) {
    const W = 880, H = 320;
    const layers = [
      { x: 100, n: 4, label: 'INPUT (4)' },
      { x: 320, n: 6, label: 'HIDDEN (6)' },
      { x: 540, n: 6, label: 'HIDDEN (6)' },
      { x: 760, n: 3, label: 'OUTPUT (3)' },
    ];
    const positions = layers.map(l => {
      const sp = Math.min(40, (H-80) / Math.max(1, l.n-1));
      const total = sp * (l.n-1);
      return Array.from({length:l.n},(_,i)=>({x:l.x, y:H/2 - total/2 + i*sp}));
    });
    let edges = '';
    for (let i = 0; i < positions.length-1; i++) {
      for (const a of positions[i]) {
        for (const b of positions[i+1]) {
          edges += `<line x1="${a.x}" y1="${a.y}" x2="${b.x}" y2="${b.y}" stroke="#1a1a1a" stroke-width="0.5" opacity="${0.08+Math.random()*0.18}"/>`;
        }
      }
    }
    let nodes = '';
    positions.forEach((layer, li) => {
      layer.forEach(p => {
        const fill = li === 0 ? '#dde7f7' : li === positions.length-1 ? '#d6e8de' : '#fbfaf6';
        const stroke = li === 0 ? '#1f6feb' : li === positions.length-1 ? '#1a7a4c' : '#1a1a1a';
        nodes += `<circle cx="${p.x}" cy="${p.y}" r="9" fill="${fill}" stroke="${stroke}" stroke-width="1.3"/>`;
      });
    });
    let labels = '';
    layers.forEach(l => {
      labels += `<text x="${l.x}" y="${H-15}" text-anchor="middle" font-family="JetBrains Mono" font-size="10" fill="#8a877f" letter-spacing="1">${l.label}</text>`;
    });
    figNet.innerHTML = `<svg viewBox="0 0 ${W} ${H}" width="100%">${edges}${nodes}${labels}</svg>`;
  }

  // ---- Loss landscape ----
  const figLand = document.getElementById('fig-landscape');
  if (figLand) {
    const W = 720, H = 320;
    // Draw concentric ovals as a contour map + a descent path
    let contours = '';
    for (let i = 1; i <= 7; i++) {
      const rx = i * 38;
      const ry = i * 24;
      const op = 0.08 + i * 0.05;
      contours += `<ellipse cx="${W/2+30}" cy="${H/2+10}" rx="${rx}" ry="${ry}" fill="none" stroke="#c84e1d" stroke-width="0.8" opacity="${op}" transform="rotate(-18 ${W/2+30} ${H/2+10})"/>`;
    }
    // descent path
    const path = [
      [120, 60], [180, 100], [230, 140], [270, 175], [310, 200], [340, 220], [365, 232], [380, 240], [388, 246]
    ];
    let pathD = path.map((p,i)=>(i?'L':'M')+' '+p[0]+' '+p[1]).join(' ');
    let dots = path.map((p,i)=>`<circle cx="${p[0]}" cy="${p[1]}" r="${i===path.length-1?5:3}" fill="${i===path.length-1?'#1a7a4c':'#1a1a1a'}"/>`).join('');
    figLand.innerHTML = `
      <svg viewBox="0 0 ${W} ${H}" width="100%">
        ${contours}
        <text x="${W/2+30}" y="${H/2+14}" text-anchor="middle" font-family="JetBrains Mono" font-size="10" fill="#c84e1d" letter-spacing="1">MIN</text>
        <path d="${pathD}" stroke="#1a1a1a" stroke-width="1.5" fill="none" stroke-dasharray="3 3"/>
        ${dots}
        <text x="120" y="48" font-family="JetBrains Mono" font-size="10" fill="#8a877f">START · random init</text>
        <text x="20" y="${H-20}" font-family="JetBrains Mono" font-size="10" fill="#8a877f">w₁ →</text>
        <text x="20" y="30" font-family="JetBrains Mono" font-size="10" fill="#8a877f">↑ w₂</text>
      </svg>
    `;
  }

  // ============================================================
  // React-mounted interactive figures (forward, backprop, lr, vanish)
  // ============================================================
  // We rely on React being available — this whole script is type="text/babel".
  // Mounting happens at the end (window load) so DOM nodes exist.
  window.__mountAnnFigures = function() {
    if (window.__annMounted) return;
    if (!window.React || !window.ReactDOM) return;
    window.__annMounted = true;

    const { useState, useEffect, useRef } = React;

    // ---------- Forward pass animation ----------
    function ForwardPassFig() {
      const layers = [4, 6, 6, 4, 3];
      const W = 900, H = 320;
      const xs = layers.map((_,i) => 80 + i * (W-160) / (layers.length-1));
      const positions = layers.map((n, i) => {
        const sp = Math.min(40, (H-80) / Math.max(1, n-1));
        const total = sp * (n-1);
        return Array.from({length:n},(_,j)=>({x:xs[i], y:H/2 - total/2 + j*sp}));
      });
      const [tick, setTick] = useState(0);
      const [playing, setPlaying] = useState(true);
      useEffect(() => {
        if (!playing) return;
        const id = setInterval(() => setTick(t => (t+1) % 600), 30);
        return () => clearInterval(id);
      }, [playing]);

      // moving "front" of activation
      const T = (tick % 200) / 200;  // 0..1
      const segIdx = Math.min(layers.length - 2, Math.floor(T * (layers.length - 1)));
      const segT = (T * (layers.length - 1)) - segIdx;

      // Determine which layers are "lit"
      const litUpTo = Math.floor(T * layers.length);

      const edges = [];
      for (let i = 0; i < positions.length - 1; i++) {
        for (const a of positions[i]) {
          for (const b of positions[i+1]) {
            const lit = i === segIdx;
            edges.push(
              <line key={`${i}-${a.x}-${a.y}-${b.x}-${b.y}`}
                x1={a.x} y1={a.y} x2={b.x} y2={b.y}
                stroke={lit ? '#c84e1d' : '#1a1a1a'}
                strokeWidth={lit ? 1.1 : 0.5}
                opacity={lit ? 0.5 : 0.12}/>
            );
          }
        }
      }

      const nodes = [];
      positions.forEach((layer, li) => {
        const lit = li <= litUpTo;
        layer.forEach((p, ni) => {
          const fill = li === 0 ? '#dde7f7' : li === positions.length-1 ? '#d6e8de' : (lit ? '#f6e4d8' : '#fbfaf6');
          const stroke = li === 0 ? '#1f6feb' : li === positions.length-1 ? '#1a7a4c' : (lit ? '#c84e1d' : '#1a1a1a');
          const r = lit ? 11 : 9;
          nodes.push(<circle key={`n-${li}-${ni}`} cx={p.x} cy={p.y} r={r} fill={fill} stroke={stroke} strokeWidth={1.3}/>);
        });
      });

      // pulse particles between segIdx and segIdx+1
      const pulses = [];
      const A = positions[segIdx];
      const B = positions[segIdx+1];
      if (A && B) {
        for (let i = 0; i < A.length; i++) {
          for (let j = 0; j < B.length; j++) {
            const x = A[i].x + (B[j].x - A[i].x) * segT;
            const y = A[i].y + (B[j].y - A[i].y) * segT;
            pulses.push(<circle key={`p-${i}-${j}`} cx={x} cy={y} r={1.6} fill="#c84e1d" opacity={0.7}/>);
          }
        }
      }

      // labels
      const labels = ['INPUT', 'HIDDEN 1', 'HIDDEN 2', 'HIDDEN 3', 'OUTPUT'];
      const labelEls = layers.map((_,i) => (
        <text key={'lb'+i} x={xs[i]} y={H-12} textAnchor="middle"
          fontFamily="JetBrains Mono" fontSize="10" fill="#8a877f" letterSpacing="1">{labels[i]}</text>
      ));

      // probability bars at output
      const outProbs = [0.12, 0.71, 0.17];
      const outX = xs[xs.length-1] + 30;
      const probEls = outProbs.map((p,i) => {
        const y = H/2 - 30 + i * 22;
        return (
          <g key={'pr'+i}>
            <rect x={outX} y={y-8} width={60*p} height={14} fill="#1a7a4c" opacity={0.3+p*0.7}/>
            <text x={outX-8} y={y+3} textAnchor="end" fontFamily="JetBrains Mono" fontSize="11" fill="#1a7a4c">cls {i+1}</text>
            <text x={outX+62} y={y+3} fontFamily="JetBrains Mono" fontSize="11" fill="#1a7a4c">{(p*100).toFixed(0)}%</text>
          </g>
        );
      });

      return (
        <div>
          <div className="fig-title"><strong>Forward pass — animated</strong><span>data flowing left → right</span></div>
          <svg viewBox={`0 0 ${W+120} ${H}`} width="100%">
            {edges}
            {pulses}
            {nodes}
            {labelEls}
            {probEls}
          </svg>
          <div className="fig-controls">
            <button className="btn-ghost btn-sm" onClick={() => setPlaying(p => !p)}>{playing ? '⏸ Pause' : '▶ Play'}</button>
            <button className="btn-ghost btn-sm" onClick={() => setTick(0)}>↺ Reset</button>
            <span className="ctrl-label" style={{marginLeft:12}}>step</span>
            <input type="range" min="0" max="599" value={tick} onChange={e => { setPlaying(false); setTick(+e.target.value); }} />
            <span className="ctrl-value">{tick}</span>
          </div>
        </div>
      );
    }

    const fwdMount = document.getElementById('fig-forward-mount');
    if (fwdMount) ReactDOM.createRoot(fwdMount).render(<ForwardPassFig/>);

    // ---------- Backprop animation ----------
    function BackpropFig() {
      const layers = [3, 5, 5, 3, 2];
      const W = 900, H = 280;
      const xs = layers.map((_,i) => 80 + i * (W-160) / (layers.length-1));
      const positions = layers.map((n, i) => {
        const sp = Math.min(36, (H-80) / Math.max(1, n-1));
        const total = sp * (n-1);
        return Array.from({length:n},(_,j)=>({x:xs[i], y:H/2 - total/2 + j*sp}));
      });
      const [tick, setTick] = useState(0);
      useEffect(() => {
        const id = setInterval(() => setTick(t => (t+1) % 800), 30);
        return () => clearInterval(id);
      }, []);

      // Phase: 0..0.5 forward, 0.5..1 backward
      const T = (tick % 400) / 400;
      const isForward = T < 0.5;
      const phaseT = isForward ? T * 2 : (T - 0.5) * 2;  // 0..1 within phase

      const segCount = layers.length - 1;
      const segIdx = isForward
        ? Math.min(segCount-1, Math.floor(phaseT * segCount))
        : Math.max(0, segCount - 1 - Math.floor(phaseT * segCount));
      const segT = (phaseT * segCount) - Math.floor(phaseT * segCount);

      const edges = [];
      for (let i = 0; i < positions.length - 1; i++) {
        for (const a of positions[i]) {
          for (const b of positions[i+1]) {
            const lit = i === segIdx;
            const color = isForward ? '#1f6feb' : '#c84e1d';
            edges.push(
              <line key={`bp-${i}-${a.x}-${a.y}-${b.x}-${b.y}`}
                x1={a.x} y1={a.y} x2={b.x} y2={b.y}
                stroke={lit ? color : '#1a1a1a'}
                strokeWidth={lit ? 1.1 : 0.4}
                opacity={lit ? 0.55 : 0.1}/>
            );
          }
        }
      }
      const nodes = [];
      positions.forEach((layer, li) => layer.forEach((p, ni) => {
        nodes.push(<circle key={`bn-${li}-${ni}`} cx={p.x} cy={p.y} r={8} fill="#fbfaf6" stroke="#1a1a1a" strokeWidth={1.2}/>);
      }));

      // pulses
      const pulses = [];
      const A = positions[segIdx];
      const B = positions[segIdx+1];
      if (A && B) {
        for (let i = 0; i < A.length; i++) {
          for (let j = 0; j < B.length; j++) {
            const fromA = isForward;
            const x = fromA ? A[i].x + (B[j].x - A[i].x) * segT : B[j].x + (A[i].x - B[j].x) * segT;
            const y = fromA ? A[i].y + (B[j].y - A[i].y) * segT : B[j].y + (A[i].y - B[j].y) * segT;
            pulses.push(<circle key={`bpp-${i}-${j}`} cx={x} cy={y} r={1.8} fill={isForward ? '#1f6feb' : '#c84e1d'}/>);
          }
        }
      }

      // Loss bubble at right
      const lossX = xs[xs.length-1] + 60;
      const lossY = H/2;

      return (
        <div className="fig fig-wide">
          <div className="fig-title"><strong>Forward & backward pass</strong>
            <span style={{color: isForward ? '#1f6feb' : '#c84e1d', fontWeight:600}}>
              {isForward ? '→ FORWARD: activations' : '← BACKWARD: gradients'}
            </span>
          </div>
          <svg viewBox={`0 0 ${W+120} ${H}`} width="100%">
            {edges}
            {pulses}
            {nodes}
            {/* Loss node */}
            <circle cx={lossX} cy={lossY} r={20} fill={isForward ? '#fbfaf6' : '#f6e4d8'} stroke="#c84e1d" strokeWidth={1.4}/>
            <text x={lossX} y={lossY+4} textAnchor="middle" fontFamily="Fraunces" fontSize="14" fontWeight="600" fill="#c84e1d">L</text>
            <text x={lossX} y={lossY+38} textAnchor="middle" fontFamily="JetBrains Mono" fontSize="10" fill="#8a877f">LOSS</text>
            {/* connector from output to loss */}
            {positions[positions.length-1].map((p, i) =>
              <line key={'lc'+i} x1={p.x} y1={p.y} x2={lossX} y2={lossY} stroke="#c84e1d" strokeWidth={0.5} opacity={0.3}/>
            )}
          </svg>
        </div>
      );
    }
    const bpMount = document.getElementById('fig-backprop-mount');
    if (bpMount) ReactDOM.createRoot(bpMount).render(<BackpropFig/>);

    // ---------- Learning rate playground ----------
    function LRFig() {
      const [lr, setLr] = useState(0.05);
      const [seed, setSeed] = useState(0);
      // Simulate gradient descent on a 2D bowl: f(x,y) = a*x^2 + b*y^2
      // Path persists; recompute when lr or seed changes.
      const path = React.useMemo(() => {
        const rng = mulberry32(seed * 7919 + 13);
        let x = -1.6 + rng()*0.4, y = 1.4 - rng()*0.4;
        const a = 1.0, b = 4.0;
        const pts = [[x, y]];
        for (let i = 0; i < 60; i++) {
          const gx = 2*a*x, gy = 2*b*y;
          x = x - lr * gx;
          y = y - lr * gy;
          if (Math.abs(x) > 5 || Math.abs(y) > 5) { pts.push([x, y]); break; }
          pts.push([x, y]);
        }
        return pts;
      }, [lr, seed]);

      const losses = path.map(([x,y]) => x*x + 4*y*y);

      // Project to SVG: x in [-2,2], y in [-2,2]
      const W = 880, H = 320;
      // left: 2D landscape, right: loss curve
      const lw = 380, lh = 280;
      const lx0 = 30, ly0 = 20;
      const X = (x) => lx0 + ((x + 2) / 4) * lw;
      const Y = (y) => ly0 + (1 - (y + 2) / 4) * lh;

      const contourEls = [];
      for (let i = 1; i <= 6; i++) {
        const rx = i * 30;
        const ry = i * 15;
        contourEls.push(<ellipse key={'c'+i} cx={X(0)} cy={Y(0)} rx={rx} ry={ry} fill="none" stroke="#c84e1d" strokeWidth="0.8" opacity={0.07 + i*0.05}/>);
      }
      const pathD = path.map(([x,y],i)=>(i?'L':'M')+X(x).toFixed(1)+' '+Y(y).toFixed(1)).join(' ');
      const dots = path.map(([x,y],i) => <circle key={'d'+i} cx={X(x)} cy={Y(y)} r={i === 0 ? 4 : i === path.length-1 ? 5 : 2} fill={i===0?'#1f6feb':i===path.length-1?'#1a7a4c':'#1a1a1a'}/>);

      // loss curve
      const cw = 380, ch = 280;
      const cx0 = 470, cy0 = 20;
      const lossMax = Math.max(...losses, 8);
      const cX = (i) => cx0 + (i / Math.max(1, losses.length-1)) * cw;
      const cY = (v) => cy0 + (1 - Math.min(v, lossMax) / lossMax) * ch;
      const lossPath = losses.map((v,i)=>(i?'L':'M')+cX(i).toFixed(1)+' '+cY(v).toFixed(1)).join(' ');

      // Verdict
      let verdict = '';
      let verdictColor = '#1a7a4c';
      const final = losses[losses.length-1];
      const div = path.some(([x,y]) => Math.abs(x)>3 || Math.abs(y)>3);
      if (div) { verdict = 'diverges — learning rate too high'; verdictColor = '#c84e1d'; }
      else if (final < 0.05) { verdict = 'converged smoothly'; verdictColor = '#1a7a4c'; }
      else if (final < 0.5) { verdict = 'converging slowly'; verdictColor = '#b8860b'; }
      else { verdict = 'too slow — lr too low'; verdictColor = '#b8860b'; }

      return (
        <div>
          <div className="fig-title"><strong>Learning rate playground</strong><span>same problem, different lr</span></div>
          <svg viewBox={`0 0 ${W} ${H}`} width="100%">
            {/* Landscape */}
            <rect x={lx0} y={ly0} width={lw} height={lh} fill="none" stroke="#c9c2ad" strokeWidth="0.8"/>
            {contourEls}
            <text x={X(0)} y={Y(0)+4} textAnchor="middle" fontFamily="JetBrains Mono" fontSize="10" fill="#c84e1d">MIN</text>
            <path d={pathD} stroke="#1a1a1a" strokeWidth={1.2} fill="none"/>
            {dots}
            <text x={lx0} y={ly0-6} fontFamily="JetBrains Mono" fontSize="10" fill="#8a877f">LOSS LANDSCAPE</text>

            {/* Loss curve */}
            <rect x={cx0} y={cy0} width={cw} height={ch} fill="none" stroke="#c9c2ad" strokeWidth="0.8"/>
            <path d={lossPath} stroke="#c84e1d" strokeWidth={1.6} fill="none"/>
            {/* horizontal grid */}
            {[0.25, 0.5, 0.75].map(g => (
              <line key={'g'+g} x1={cx0} y1={cy0+g*ch} x2={cx0+cw} y2={cy0+g*ch} stroke="#c9c2ad" strokeWidth="0.4" strokeDasharray="2 3"/>
            ))}
            <text x={cx0} y={cy0-6} fontFamily="JetBrains Mono" fontSize="10" fill="#8a877f">LOSS OVER STEPS</text>
            <text x={cx0+cw} y={cy0+ch+18} fontFamily="JetBrains Mono" fontSize="10" fill="#8a877f" textAnchor="end">step {losses.length-1}</text>
          </svg>
          <div className="fig-controls">
            <span className="ctrl-label">learning rate</span>
            <input type="range" min="0.001" max="0.6" step="0.001" value={lr} onChange={e => setLr(+e.target.value)} />
            <span className="ctrl-value">{lr.toFixed(3)}</span>
            <button className="btn-ghost btn-sm" onClick={() => setSeed(s => s+1)}>↻ new start</button>
            <span style={{flex:1}}/>
            <span className="ctrl-label" style={{color:verdictColor, fontWeight:600}}>{verdict}</span>
          </div>
        </div>
      );
    }
    function mulberry32(a) {
      return function() {
        a |= 0; a = a + 0x6D2B79F5 | 0;
        let t = a;
        t = Math.imul(t ^ t >>> 15, t | 1);
        t ^= t + Math.imul(t ^ t >>> 7, t | 61);
        return ((t ^ t >>> 14) >>> 0) / 4294967296;
      };
    }
    const lrMount = document.getElementById('fig-lr-mount');
    if (lrMount) ReactDOM.createRoot(lrMount).render(<LRFig/>);

    // ---------- Vanishing gradient figure ----------
    function VanishFig() {
      const [activation, setActivation] = useState('sigmoid');
      const layers = 12;
      // simulated gradient magnitudes
      const data = React.useMemo(() => {
        const arr = [];
        let g = 1.0;
        for (let i = 0; i < layers; i++) {
          arr.push(g);
          if (activation === 'sigmoid') g *= 0.45 + Math.random()*0.1;
          else if (activation === 'tanh') g *= 0.55 + Math.random()*0.15;
          else g *= 0.92 + Math.random()*0.10; // ReLU loses very little
        }
        return arr.reverse();  // layer 1 (deepest from output) on left
      }, [activation]);

      const W = 760, H = 240;
      const barW = (W - 80) / layers;
      const bars = data.map((g, i) => {
        const h = Math.max(2, Math.min(180, Math.log10(g+1e-12) * 30 + 180));
        const color = activation === 'relu' ? '#1a7a4c' : activation === 'tanh' ? '#b8860b' : '#c84e1d';
        return (
          <g key={'b'+i}>
            <rect x={40 + i*barW + 4} y={H - 40 - h} width={barW - 8} height={h} fill={color} opacity={0.85}/>
            <text x={40 + i*barW + barW/2} y={H - 26} textAnchor="middle" fontFamily="JetBrains Mono" fontSize="10" fill="#8a877f">L{i+1}</text>
            <text x={40 + i*barW + barW/2} y={H - 40 - h - 6} textAnchor="middle" fontFamily="JetBrains Mono" fontSize="9" fill="#8a877f">{g < 1e-4 ? g.toExponential(0) : g.toFixed(3)}</text>
          </g>
        );
      });

      return (
        <div>
          <div className="fig-title"><strong>Gradient magnitude per layer</strong><span>L1 = first layer (furthest from loss)</span></div>
          <svg viewBox={`0 0 ${W} ${H}`} width="100%">
            <line x1="40" y1={H-40} x2={W-20} y2={H-40} stroke="#c9c2ad" strokeWidth={0.8}/>
            {bars}
          </svg>
          <div className="fig-controls">
            <span className="ctrl-label">activation</span>
            {['sigmoid', 'tanh', 'relu'].map(a => (
              <button key={a} className={"btn-ghost btn-sm"} style={{
                background: activation===a ? '#1a1a1a' : 'transparent',
                color: activation===a ? '#fbfaf6' : '#1a1a1a',
                borderColor: '#1a1a1a',
              }} onClick={() => setActivation(a)}>{a}</button>
            ))}
          </div>
        </div>
      );
    }
    const vanMount = document.getElementById('fig-vanish-mount');
    if (vanMount) ReactDOM.createRoot(vanMount).render(<VanishFig/>);
  };
})();
