/* ============================================================
   PART III — RNN
   ============================================================ */

(function() {
  const ANCHOR = document.getElementById('anchor-rnn');

  const html = `
<!-- ========== 17 Why sequence is different ========== -->
<article id="rnn-why" class="screen" data-screen-label="17 Why sequence is different">
  <div class="section-head">
    <div class="section-eyebrow">Part III · Section 17 · 4 min</div>
    <h2>Sequence has a property neither images nor tables have: order.</h2>
    <p class="section-lede">"The cat sat on the mat" and "Mat the on sat cat the" use identical tokens. They mean very different things. The architecture must respect that.</p>
  </div>
  <div class="prose">
    <p>For images, we exploited spatial locality with convolutions. For sequences — text, audio, time series, anything indexed by <em>t</em> — we need an architecture that:</p>
    <ul>
      <li>Processes one token at a time, in order.</li>
      <li>Carries a <b>state</b> that summarizes everything it has seen so far.</li>
      <li>Updates that state with each new token.</li>
      <li>Can produce an output at any (or every) step.</li>
    </ul>

    <div class="fig">
      <div class="fig-title"><strong>The same model, different sequence applications</strong><span>one-to-one is just an ANN</span></div>
      <div id="fig-seq-modes"></div>
    </div>
    <div class="caption">Sequence problems come in shapes: one-to-many (image captioning), many-to-one (sentiment), many-to-many same-length (POS tagging), many-to-many different-length (translation). RNNs handle them all.</div>
  </div>
</article>

<!-- ========== 18 The recurrent cell ========== -->
<article id="rnn-cell" class="screen" data-screen-label="18 The recurrent cell">
  <div class="section-head">
    <div class="section-eyebrow">Part III · Section 18 · 5 min</div>
    <h2>One cell. Used over and over.</h2>
    <p class="section-lede">An RNN is a single dense layer with a twist: it sees not just the current input, but its own previous output.</p>
  </div>
  <div class="prose">
    <div class="fig">
      <div class="fig-title"><strong>The recurrent cell</strong><span>folded view</span></div>
      <div id="fig-rnn-cell"></div>
    </div>
    <div class="caption">At every step, the cell receives the current input <span class="tok-data">x<sub>t</sub></span> AND its own previous hidden state <span class="tok-mem">h<sub>t−1</sub></span>. It outputs a new state <span class="tok-mem">h<sub>t</sub></span>, which becomes the input for step t+1.</div>

    <p>The math is one line, on top of what you already know:</p>
    <div class="eq">
      <span class="var">h</span><span class="sub">t</span> <span class="op">=</span> tanh(<span class="var">W</span><span class="sub">xh</span> <span class="var">x</span><span class="sub">t</span> <span class="op">+</span> <span class="var">W</span><span class="sub">hh</span> <span class="var">h</span><span class="sub">t−1</span> <span class="op">+</span> <span class="var">b</span>)
      <span class="lbl">Same W, every step. The "recurrent" weights are reused.</span>
    </div>

    <p>Like a CNN reuses its kernel across spatial positions, an RNN reuses its weights across time positions. <b>Weight sharing through time</b> is the parallel.</p>

    <div class="callout">
      <div class="callout-title">Hidden state as compressed memory</div>
      <p>Whatever the network "remembers" about the past must fit inside <em>h</em>. If <em>h</em> is a 128-vector, the network has 128 floats to summarize "everything it has seen until now." This bottleneck is why long sequences are hard.</p>
    </div>
  </div>
</article>

<!-- ========== 19 Unrolling through time ========== -->
<article id="rnn-unroll" class="screen" data-screen-label="19 Unrolling">
  <div class="section-head">
    <div class="section-eyebrow">Part III · Section 19 · 5 min · animated</div>
    <h2>Unroll the RNN to see it as a deep network.</h2>
    <p class="section-lede">If you draw out the same cell once per time step, what you get looks suspiciously like a very deep feedforward network — with weight sharing.</p>
  </div>
  <div class="prose">
    <div class="fig fig-wide" id="fig-unroll-mount"></div>
    <div class="caption">Same cell, unrolled across 6 time steps. Watch the hidden state propagate left to right. The <span class="tok-data">input</span> at each step combines with the <span class="tok-mem">previous state</span> to produce the <span class="tok-mem">next state</span> and an <span class="tok-out">output</span>.</div>

    <p>This view makes two things obvious:</p>
    <ol>
      <li><b>It really is a deep network</b> — for a 100-step sequence, you have a 100-layer network, all sharing weights.</li>
      <li><b>Backprop applies as usual</b> — except now it has to flow back through time, not just through layers. We call this <em>backpropagation through time</em>, or <b>BPTT</b>.</li>
    </ol>

    <h3>Two flavors of output</h3>
    <ul>
      <li><b>Last-step output</b> (many-to-one): use only <em>h</em><sub>T</sub>. Good for sentiment, classification.</li>
      <li><b>Every-step output</b> (many-to-many): emit <em>y</em><sub>t</sub> at each step. Good for tagging, language modeling.</li>
    </ul>
    <p>In Keras: <code>SimpleRNN(64, return_sequences=False)</code> vs <code>SimpleRNN(64, return_sequences=True)</code>. One flag, big difference.</p>
  </div>
</article>

<!-- ========== 20 BPTT ========== -->
<article id="rnn-bptt" class="screen" data-screen-label="20 BPTT">
  <div class="section-head">
    <div class="section-eyebrow">Part III · Section 20 · 6 min</div>
    <h2>BPTT, and why long-range memory fails.</h2>
    <p class="section-lede">Backprop through time is mathematically clean and operationally a nightmare — for the same reason as the vanishing gradient.</p>
  </div>
  <div class="prose">
    <p>Recall: backprop multiplies gradients through layers. In an RNN unrolled to <em>T</em> steps, gradients flowing from time <em>T</em> back to time <em>1</em> get multiplied by the <em>same recurrent weight matrix</em> <em>T</em> times.</p>

    <div class="eq">
      ∂L/∂h<span class="sub">1</span> <span class="op">∝</span> (W<span class="sub">hh</span>)<span class="sup">T−1</span>
      <span class="lbl">if eigenvalue &lt; 1: vanish · &gt; 1: explode</span>
    </div>

    <div class="fig" id="fig-bptt-mount"></div>
    <div class="caption">Gradient magnitude as it flows backward through time. With a recurrent weight ≈ 0.5, the signal halves each step. By step 20, it's effectively zero. By step 50, it's literally zero in float32.</div>

    <p>Two failure modes, mirror images:</p>
    <ul>
      <li><b>Vanishing gradients.</b> Eigenvalues &lt; 1 → signal decays exponentially. Network can't learn long-range dependencies. <em>This is the common case</em>.</li>
      <li><b>Exploding gradients.</b> Eigenvalues &gt; 1 → signal blows up to NaN. Easy fix: gradient clipping. Add <code>clipnorm=1.0</code> to your optimizer and forget about it.</li>
    </ul>

    <h3>The famous example</h3>
    <p class="aside">"In France, where I grew up speaking ___." A vanilla RNN, even at sequence length 30, struggles to remember "France" by the time it predicts the blank. The relevant signal has decayed below the noise floor. This is precisely the problem LSTM was invented to solve.</p>

    <p>And so we arrive at the last station of our tour.</p>
  </div>
</article>

<!-- ========== 21 RNN in Keras ========== -->
<article id="rnn-code" class="screen" data-screen-label="21 RNN in Keras">
  <div class="section-head">
    <div class="section-eyebrow">Part III · Section 21 · 3 min</div>
    <h2>An RNN in Keras.</h2>
    <p class="section-lede">A sentiment classifier, IMDB-style.</p>
  </div>
  <div class="prose">
    <pre class="code"><span class="code-tag">tensorflow / keras</span><span class="kw">from</span> tensorflow.keras <span class="kw">import</span> layers, models

VOCAB, EMBED, MAXLEN = <span class="num">10000</span>, <span class="num">64</span>, <span class="num">200</span>

model = models.<span class="fn">Sequential</span>([
    layers.<span class="fn">Input</span>(shape=(MAXLEN,)),
    layers.<span class="fn">Embedding</span>(VOCAB, EMBED),         <span class="com"># word index → vector</span>
    layers.<span class="fn">SimpleRNN</span>(<span class="num">64</span>),                  <span class="com"># return last state only</span>
    layers.<span class="fn">Dense</span>(<span class="num">1</span>, activation=<span class="str">'sigmoid'</span>)   <span class="com"># pos / neg</span>
])

model.<span class="fn">compile</span>(optimizer=tf.keras.optimizers.<span class="fn">Adam</span>(<span class="num">1e-3</span>, clipnorm=<span class="num">1.0</span>),
              loss=<span class="str">'binary_crossentropy'</span>, metrics=[<span class="str">'accuracy'</span>])
model.<span class="fn">fit</span>(x_train, y_train, epochs=<span class="num">5</span>, batch_size=<span class="num">64</span>)</pre>

    <h3>Things to notice</h3>
    <ul>
      <li><b><code>Embedding</code></b> turns an integer word index into a learned dense vector. Without it, RNNs would have to start from one-hot vectors (10,000-D for a 10k vocab — wasteful).</li>
      <li><b><code>SimpleRNN(64)</code></b> creates a 64-dimensional hidden state. Default <code>return_sequences=False</code> means we get only <em>h</em><sub>T</sub>.</li>
      <li><b><code>clipnorm=1.0</code></b> guards against exploding gradients. Cheap insurance.</li>
    </ul>

    <div class="callout">
      <div class="callout-title">Honest disclaimer</div>
      <p>You will probably never train a <code>SimpleRNN</code> in production. They train slowly and forget quickly. We're showing it to make the LSTM, next, feel like the obvious solution it is. In practice, replace <code>SimpleRNN</code> with <code>LSTM</code> in this exact code and you have a much better model.</p>
    </div>
  </div>
</article>
  `;
  if (ANCHOR) ANCHOR.insertAdjacentHTML('afterend', html);

  // ============================================================
  // Static figures
  // ============================================================

  // Sequence modes
  const sm = document.getElementById('fig-seq-modes');
  if (sm) {
    function box(x,y,fill,stroke){return `<rect x="${x-12}" y="${y-12}" width="24" height="24" fill="${fill}" stroke="${stroke}" stroke-width="1.2"/>`;}
    function modes(x0, label, sub, inputs, outputs, hidden) {
      let g = '';
      const stepX = 28;
      const baseY = 90;
      const inY = baseY + 50, outY = baseY - 50, hY = baseY;
      // hidden line
      g += `<line x1="${x0}" y1="${hY}" x2="${x0 + (hidden-1)*stepX}" y2="${hY}" stroke="#b8860b" stroke-width="1.2" stroke-dasharray="2 2"/>`;
      for (let i = 0; i < hidden; i++) {
        const x = x0 + i*stepX;
        g += box(x, hY, '#f1e4c2', '#b8860b');
      }
      // inputs
      inputs.forEach((t, i) => {
        const x = x0 + t*stepX;
        g += box(x, inY, '#dde7f7', '#1f6feb');
        g += `<line x1="${x}" y1="${inY-12}" x2="${x}" y2="${hY+12}" stroke="#1f6feb" stroke-width="1"/>`;
      });
      // outputs
      outputs.forEach((t, i) => {
        const x = x0 + t*stepX;
        g += box(x, outY, '#d6e8de', '#1a7a4c');
        g += `<line x1="${x}" y1="${hY-12}" x2="${x}" y2="${outY+12}" stroke="#1a7a4c" stroke-width="1"/>`;
      });
      g += `<text x="${x0 + (hidden-1)*stepX/2}" y="170" text-anchor="middle" font-family="Fraunces" font-size="13" font-weight="600">${label}</text>`;
      g += `<text x="${x0 + (hidden-1)*stepX/2}" y="186" text-anchor="middle" font-family="JetBrains Mono" font-size="10" fill="#8a877f">${sub}</text>`;
      return g;
    }
    sm.innerHTML = `
      <svg viewBox="0 0 880 220" width="100%">
        ${modes(40,  'one-to-many',   'image → caption',     [0],         [1,2,3,4],     5)}
        ${modes(220, 'many-to-one',   'tweet → sentiment',   [0,1,2,3,4], [4],           5)}
        ${modes(420, 'many-to-many',  'words → POS tags',    [0,1,2,3,4], [0,1,2,3,4],   5)}
        ${modes(640, 'seq-to-seq',    'EN → FR',             [0,1,2],     [3,4,5,6],     7)}
        <text x="40" y="206" font-family="JetBrains Mono" font-size="9" fill="#1f6feb">▪ input</text>
        <text x="100" y="206" font-family="JetBrains Mono" font-size="9" fill="#b8860b">▪ hidden</text>
        <text x="170" y="206" font-family="JetBrains Mono" font-size="9" fill="#1a7a4c">▪ output</text>
      </svg>
    `;
  }

  // RNN cell folded
  const rc = document.getElementById('fig-rnn-cell');
  if (rc) {
    rc.innerHTML = `
      <svg viewBox="0 0 720 240" width="100%">
        <!-- input -->
        <rect x="120" y="180" width="50" height="34" fill="#dde7f7" stroke="#1f6feb" stroke-width="1.4" rx="3"/>
        <text x="145" y="201" text-anchor="middle" font-family="Fraunces" font-size="14" font-weight="600" fill="#1f6feb">x</text>
        <text x="151" y="206" font-family="Fraunces" font-size="9" fill="#1f6feb">t</text>

        <!-- cell -->
        <rect x="290" y="100" width="140" height="80" fill="#fbfaf6" stroke="#1a1a1a" stroke-width="1.6" rx="6"/>
        <text x="360" y="135" text-anchor="middle" font-family="Fraunces" font-size="20" font-weight="600">RNN cell</text>
        <text x="360" y="160" text-anchor="middle" font-family="JetBrains Mono" font-size="11" fill="#8a877f">tanh(W·x + U·h + b)</text>

        <!-- output -->
        <rect x="550" y="123" width="50" height="34" fill="#f1e4c2" stroke="#b8860b" stroke-width="1.4" rx="3"/>
        <text x="575" y="144" text-anchor="middle" font-family="Fraunces" font-size="14" font-weight="600" fill="#b8860b">h</text>
        <text x="581" y="149" font-family="Fraunces" font-size="9" fill="#b8860b">t</text>

        <!-- recurrence loop -->
        <path d="M 600 140 Q 660 140 660 70 Q 660 30 360 30 Q 200 30 200 100 Q 200 100 290 110"
          stroke="#b8860b" stroke-width="1.6" fill="none" stroke-dasharray="4 3" marker-end="url(#rnnArr)"/>
        <text x="430" y="22" text-anchor="middle" font-family="Fraunces" font-style="italic" font-size="13" fill="#b8860b">h fed back: the recurrence</text>

        <!-- arrows -->
        <line x1="172" y1="195" x2="290" y2="160" stroke="#1f6feb" stroke-width="1.4" marker-end="url(#rnnArr)"/>
        <line x1="430" y1="140" x2="548" y2="140" stroke="#b8860b" stroke-width="1.4" marker-end="url(#rnnArr)"/>

        <text x="145" y="232" text-anchor="middle" font-family="JetBrains Mono" font-size="10" fill="#8a877f">CURRENT INPUT</text>
        <text x="575" y="178" text-anchor="middle" font-family="JetBrains Mono" font-size="10" fill="#8a877f">NEW STATE</text>

        <defs><marker id="rnnArr" markerWidth="8" markerHeight="8" refX="7" refY="4" orient="auto"><polygon points="0 0, 8 4, 0 8" fill="#1a1a1a"/></marker></defs>
      </svg>
    `;
  }

  // ============================================================
  // React figures
  // ============================================================
  window.__mountRnnFigures = function() {
    if (window.__rnnMounted) return;
    if (!window.React || !window.ReactDOM) return;
    window.__rnnMounted = true;
    const { useState, useEffect, useMemo } = React;

    // ---------- Unroll animation ----------
    function UnrollFig() {
      const T = 6;
      const tokens = ['the', 'cat', 'sat', 'on', 'the', 'mat'];
      const [step, setStep] = useState(0);
      const [playing, setPlaying] = useState(true);
      useEffect(() => {
        if (!playing) return;
        const id = setInterval(() => setStep(s => (s+1) % (T+1)), 900);
        return () => clearInterval(id);
      }, [playing]);

      const W = 980, H = 320;
      const stepX = 130;
      const x0 = 80;
      const cellY = 150;
      const inY = 250;
      const outY = 50;

      const cells = [];
      for (let t = 0; t < T; t++) {
        const x = x0 + t*stepX;
        const active = t < step;
        const current = t === step - 1;
        // input
        cells.push(
          <g key={'in'+t}>
            <rect x={x-30} y={inY-18} width={60} height={36} fill={active ? '#dde7f7' : '#fbfaf6'} stroke={active ? '#1f6feb' : '#c9c2ad'} strokeWidth={active ? 1.4 : 0.8} rx={3}/>
            <text x={x} y={inY+5} textAnchor="middle" fontFamily="Fraunces" fontSize="14" fontWeight="600" fill={active ? '#1f6feb' : '#c9c2ad'}>{tokens[t]}</text>
            <text x={x} y={inY+30} textAnchor="middle" fontFamily="JetBrains Mono" fontSize="9" fill="#8a877f">x{t}</text>
          </g>
        );
        // cell
        cells.push(
          <g key={'cell'+t}>
            <rect x={x-36} y={cellY-26} width={72} height={52} fill={current ? '#f6e4d8' : active ? '#fbfaf6' : '#fbfaf6'} stroke={current ? '#c84e1d' : '#1a1a1a'} strokeWidth={current ? 1.8 : 1.2} rx={4}/>
            <text x={x} y={cellY-3} textAnchor="middle" fontFamily="Fraunces" fontSize="12" fontWeight="600">RNN</text>
            <text x={x} y={cellY+13} textAnchor="middle" fontFamily="Fraunces" fontStyle="italic" fontSize="13" fill="#b8860b">h{t+1}</text>
          </g>
        );
        // input -> cell arrow
        cells.push(<line key={'ic'+t} x1={x} y1={inY-18} x2={x} y2={cellY+26} stroke={active ? '#1f6feb' : '#c9c2ad'} strokeWidth={1.2} markerEnd="url(#unArr)"/>);
        // output line
        cells.push(
          <g key={'ot'+t}>
            <line x1={x} y1={cellY-26} x2={x} y2={outY+18} stroke={active ? '#1a7a4c' : '#c9c2ad'} strokeWidth={1.2} markerEnd="url(#unArr)"/>
            <rect x={x-28} y={outY-18} width={56} height={36} fill={active ? '#d6e8de' : '#fbfaf6'} stroke={active ? '#1a7a4c' : '#c9c2ad'} strokeWidth={active ? 1.4 : 0.8} rx={3}/>
            <text x={x} y={outY+5} textAnchor="middle" fontFamily="Fraunces" fontSize="13" fontWeight="600" fill={active ? '#1a7a4c' : '#c9c2ad'}>y{t}</text>
          </g>
        );
        // recurrent arrow to next
        if (t < T-1) {
          const nx = x0 + (t+1)*stepX;
          cells.push(
            <g key={'rec'+t}>
              <line x1={x+36} y1={cellY} x2={nx-36} y2={cellY} stroke={active ? '#b8860b' : '#c9c2ad'} strokeWidth={active ? 1.6 : 0.8} strokeDasharray="3 3" markerEnd="url(#unArrA)"/>
              <text x={(x + nx)/2} y={cellY-10} textAnchor="middle" fontFamily="JetBrains Mono" fontSize="9" fill={active ? '#b8860b' : '#c9c2ad'}>h</text>
            </g>
          );
        }
      }

      // labels
      const labels = (
        <g>
          <text x={20} y={outY+5} fontFamily="JetBrains Mono" fontSize="10" fill="#1a7a4c">y_t</text>
          <text x={20} y={cellY+5} fontFamily="JetBrains Mono" fontSize="10" fill="#1a1a1a">cell</text>
          <text x={20} y={inY+5} fontFamily="JetBrains Mono" fontSize="10" fill="#1f6feb">x_t</text>
        </g>
      );

      return (
        <div>
          <div className="fig-title"><strong>RNN unrolled across time</strong><span>step {step}/{T}</span></div>
          <svg viewBox={`0 0 ${W} ${H}`} width="100%">
            <defs>
              <marker id="unArr" markerWidth="6" markerHeight="6" refX="5" refY="3" orient="auto"><polygon points="0 0, 6 3, 0 6" fill="#1a1a1a"/></marker>
              <marker id="unArrA" markerWidth="6" markerHeight="6" refX="5" refY="3" orient="auto"><polygon points="0 0, 6 3, 0 6" fill="#b8860b"/></marker>
            </defs>
            {labels}
            {cells}
          </svg>
          <div className="fig-controls">
            <button className="btn-ghost btn-sm" onClick={() => setPlaying(p => !p)}>{playing ? '⏸ Pause' : '▶ Play'}</button>
            <button className="btn-ghost btn-sm" onClick={() => setStep(0)}>↺ Reset</button>
            <span className="ctrl-label" style={{marginLeft:12}}>step</span>
            <input type="range" min="0" max={T} value={step} onChange={e => { setPlaying(false); setStep(+e.target.value); }} />
            <span className="ctrl-value">{step}/{T}</span>
          </div>
        </div>
      );
    }
    const um = document.getElementById('fig-unroll-mount');
    if (um) ReactDOM.createRoot(um).render(<UnrollFig/>);

    // ---------- BPTT vanishing ----------
    function BPTTFig() {
      const [w, setW] = useState(0.6);
      const T = 30;
      const data = useMemo(() => {
        const arr = [];
        let g = 1.0;
        for (let i = 0; i < T; i++) {
          arr.push(g);
          g *= w;
        }
        return arr.reverse(); // step 1 leftmost
      }, [w]);

      const W = 760, H = 240;
      const barW = (W - 80) / T;
      const bars = data.map((g, i) => {
        const logG = Math.log10(Math.max(g, 1e-15));
        const h = Math.max(2, Math.min(170, (logG + 12) * 14));
        const finalLayer = i === T-1;
        return (
          <g key={'bb'+i}>
            <rect x={40 + i*barW + 2} y={H - 40 - h} width={barW - 4} height={h}
              fill={finalLayer ? '#1a7a4c' : '#c84e1d'} opacity={0.85}/>
          </g>
        );
      });
      return (
        <div>
          <div className="fig-title"><strong>Gradient flowing back through time</strong><span>step 1 (left) ← step 30 (right, near loss)</span></div>
          <svg viewBox={`0 0 ${W} ${H}`} width="100%">
            <line x1="40" y1={H-40} x2={W-20} y2={H-40} stroke="#c9c2ad" strokeWidth={0.8}/>
            <text x="40" y={H-22} fontFamily="JetBrains Mono" fontSize="10" fill="#8a877f">step 1</text>
            <text x={W-20} y={H-22} fontFamily="JetBrains Mono" fontSize="10" fill="#8a877f" textAnchor="end">step T (output)</text>
            <text x={(W)/2} y={20} fontFamily="JetBrains Mono" fontSize="10" fill="#8a877f" textAnchor="middle">log gradient magnitude</text>
            {bars}
          </svg>
          <div className="fig-controls">
            <span className="ctrl-label">recurrent weight |W|</span>
            <input type="range" min="0.3" max="1.4" step="0.01" value={w} onChange={e => setW(+e.target.value)}/>
            <span className="ctrl-value">{w.toFixed(2)}</span>
            <span style={{flex:1}}/>
            <span className="ctrl-label" style={{color: w < 1 ? '#c84e1d' : '#1f6feb', fontWeight:600}}>
              {w < 0.95 ? 'vanishing' : w > 1.05 ? 'exploding' : 'critical regime'}
            </span>
          </div>
        </div>
      );
    }
    const bm = document.getElementById('fig-bptt-mount');
    if (bm) ReactDOM.createRoot(bm).render(<BPTTFig/>);
  };
})();
