/* ============================================================
   PART II — CNN
   ============================================================ */

(function() {
  const ANCHOR = document.getElementById('anchor-cnn');

  const html = `
<!-- ========== 10 Why pixels break ANNs ========== -->
<article id="cnn-why" class="screen" data-screen-label="10 Why pixels break ANNs">
  <div class="section-head">
    <div class="section-eyebrow">Part II · Section 10 · 4 min</div>
    <h2>Why a dense network is wrong for images.</h2>
    <p class="section-lede">An ANN treats every pixel as an independent number. That throws away the most important fact about an image: nearby pixels are related.</p>
  </div>
  <div class="prose">
    <p>Imagine training the MNIST classifier we wrote at the end of Part I. To feed it a 28×28 digit, we flatten the image to a 784-vector. Now consider:</p>
    <ul>
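      <li>The pixel at <code>(14, 14)</code> ends up at index <em>406</em>. Its right-hand neighbour at <code>(14, 15)</code> ends up at index <em>407</em>, but that adjacency is <em>only</em> an accident of the flattening order: the neighbour directly below, <code>(15, 14)</code>, lands at index <em>434</em>, 28 slots away. A dense layer can't use either fact. It treats every index alike, and it would learn just as well if all 784 inputs were consistently shuffled.</li>
      <li>Shift the digit one pixel to the right and every stroke pixel lands on a different input index. Nothing in a dense layer links the shifted digit to the original, so the model has to learn each position separately, from examples.</li>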
      <li>For a single 224×224 RGB image (the size ImageNet models eat), a dense layer of 1,000 units would need <b>150 million weights</b>. For one layer.</li>
    </ul>
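    <p>The arithmetic behind those bullets is easy to check. Below is a minimal sketch that assumes row-major (C-order) flattening, which is NumPy's default for <code>reshape</code>. <code>flat_index</code> is a hypothetical helper, not a library function.</p>
    <pre class="code"><span class="code-tag">numpy · sketch</span>import numpy as np

H, W = 28, 28

def flat_index(row, col, width=W):
    # Row-major flattening lays the rows end to end, each one width pixels long.
    return row * width + col

print(flat_index(14, 14))   # 406
print(flat_index(14, 15))   # 407: right-hand neighbour, next slot
print(flat_index(15, 14))   # 434: neighbour below, a whole row (28 slots) later

# NumPy's default reshape uses the same order.
img = np.arange(H * W).reshape(H, W)
assert img.reshape(-1)[flat_index(15, 14)] == img[15, 14]

# Weights in one dense layer of 1,000 units fed a 224x224 RGB image (biases excluded).
n_inputs = 224 * 224 * 3
print(n_inputs * 1000)      # 150528000</pre>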

    <div class="fig">
      <div class="fig-title"><strong>The flattening problem</strong><span>spatial information, gone</span></div>
      <div id="fig-flatten"></div>
    </div>
    <div class="caption">A 2D image becomes a 1D row of numbers, indistinguishable in shape from a tabular dataset. Whatever neighbourhood structure existed in the pixels is destroyed at this step.</div>

    <p>We want a layer that <b>respects spatial locality</b> and <b>shares parameters across positions</b>. That layer is the convolution.</p>
  </div>
</article>

<!-- ========== 11 The convolution operation ========== -->
<article id="cnn-conv" class="screen" data-screen-label="11 Convolution">
  <div class="section-head">
    <div class="section-eyebrow">Part II · Section 11 · 6 min · animated</div>
    <h2>The convolution: a small window, slid everywhere.</h2>
    <p class="section-lede">A convolution slides a tiny grid of weights, the kernel (usually 3×3), across the image, multiplying the overlapping pixels by those weights and summing the result.</p>
  </div>
  <div class="prose">
    <p>At each position, the kernel produces <em>one number</em>. As the kernel sweeps across the image, those numbers form a new grid called a <b>feature map</b>. The feature map is slightly smaller than the input unless the edges are padded.</p>

    <div class="fig fig-wide" id="fig-conv-mount"></div>
    <div class="caption">The 3×3 <span class="tok-act">kernel</span> in the middle slides across the input on the left. At each position, it computes a weighted sum of the nine pixels under it and writes one cell into the <span class="tok-out">feature map</span> on the right. The same nine weights are used at every position.</div>
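    <p>The computation in the figure fits in a few lines of NumPy. This is a minimal sketch under simplifying assumptions: one channel, no padding, stride 1, and plain Python loops. It is not how frameworks implement convolution. Like <code>Conv2D</code> in Keras, it does not flip the kernel, so strictly speaking it computes a cross-correlation. <code>conv2d_valid</code> is a hypothetical name.</p>
    <pre class="code"><span class="code-tag">numpy · sketch</span>import numpy as np

def conv2d_valid(image, kernel):
    # Visit every position where the kernel fits entirely inside the image.
    kh, kw = kernel.shape
    out_h = image.shape[0] - kh + 1
    out_w = image.shape[1] - kw + 1
    out = np.zeros((out_h, out_w))
    for i in range(out_h):
        for j in range(out_w):
            window = image[i:i + kh, j:j + kw]
            out[i, j] = np.sum(window * kernel)   # weighted sum: one feature-map cell
    return out

# The figure's setup: a 7x7 input with a vertical step edge
# (0.15 in columns 0-3, 0.85 in columns 4-6) and a vertical-edge kernel.
image = np.full((7, 7), 0.15)
image[:, 4:] = 0.85
kernel = np.array([[-1, 0, 1],
                   [-2, 0, 2],
                   [-1, 0, 1]], dtype=float)

fmap = conv2d_valid(image, kernel)
print(fmap.shape)           # (5, 5)
print(np.round(fmap, 1))    # every row reads 0, 0, 2.8, 2.8, 0: only windows straddling the edge respond</pre>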

    <h3>Why this is brilliant, in three points</h3>
    <ol>
      <li><b>Weight sharing.</b> The same nine weights detect the feature anywhere in the image. A "vertical edge detector" works in the top-left corner and in the bottom-right.</li>
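      <li><b>Translation equivariance.</b> Move the cat 50 pixels to the right and the feature map for "cat ears" moves 50 pixels to the right with it. This holds exactly for a stride-1 conv and only approximately once strides and pooling downsample the map. The detection itself doesn't change; only its location does.</li>
      <li><b>Drastic parameter reduction.</b> A 3×3 conv filter over a single-channel input has 9 weights (+1 bias); over C input channels it has 9C + 1. A dense layer connecting two single-channel 224×224 maps would have about 2.5 billion weights. Different universes.</li>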
    </ol>
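    <p>The first and third points can be checked numerically. The sketch below places a random 4×4 "object" on an otherwise empty 12×12 canvas and uses a random 3×3 kernel. The empty border means <code>np.roll</code> only moves zeros across the edge. So that the block stands alone, it redefines a vectorized valid-mode convolution (again a cross-correlation) under the hypothetical name <code>conv2d_valid</code>.</p>
    <pre class="code"><span class="code-tag">numpy · sketch</span>import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def conv2d_valid(image, kernel):
    # View every kernel-sized window, weight it by the kernel, and sum (no kernel flip).
    windows = sliding_window_view(image, kernel.shape)
    return (windows * kernel).sum(axis=(2, 3))

rng = np.random.default_rng(0)
kernel = rng.normal(size=(3, 3))

# A small object in the middle of an otherwise empty 12x12 canvas.
image = np.zeros((12, 12))
image[4:8, 4:8] = rng.normal(size=(4, 4))

# Equivariance: convolving the shifted image equals shifting the feature map.
shift = 2
lhs = conv2d_valid(np.roll(image, shift, axis=1), kernel)
rhs = np.roll(conv2d_valid(image, kernel), shift, axis=1)
print(np.allclose(lhs, rhs))            # True

# Parameter arithmetic for single-channel maps (dense-layer biases ignored).
conv_params = 3 * 3 * 1 + 1             # 9 weights + 1 bias
dense_params = (224 * 224) ** 2         # every input pixel wired to every output pixel
print(conv_params, dense_params)        # 10 2517630976</pre>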

    <h3>The four numbers that define a conv layer</h3>
    <div class="pill-row">
      <span class="pill"><b>kernel size</b> · 3×3 most common</span>
      <span class="pill"><b>stride</b> · how many pixels per slide</span>
      <span class="pill"><b>padding</b> · zero-pad the edges?</span>
      <span class="pill"><b># filters</b> · how many distinct kernels in this layer</span>
    </div>
    <p>A typical first layer says: <em>"give me 32 different 3×3 kernels, slide them with stride 1, pad the edges so output is the same size as input."</em> Output: 32 stacked feature maps. Now the next layer can convolve over those.</p>
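    <p>Those numbers fix the output size. Take an input of width <em>n</em>, kernel size <em>k</em>, padding <em>p</em> on each side and stride <em>s</em>. The standard result is an output width of ⌊(n + 2p − k) / s⌋ + 1, and height follows the same formula. For odd <em>k</em> and stride 1, Keras's <code>padding='same'</code> corresponds to p = (k − 1)/2. A small sketch of the arithmetic, with <code>conv_output_size</code> as a hypothetical helper:</p>
    <pre class="code"><span class="code-tag">python · sketch</span>def conv_output_size(n, k, p=0, s=1):
    # Output length along one spatial axis: floor((n + 2p - k) / s) + 1
    return (n + 2 * p - k) // s + 1

print(conv_output_size(28, 3, p=1, s=1))   # 28: "same" padding keeps the size
print(conv_output_size(28, 3, p=0, s=1))   # 26: no padding trims a 1-pixel border
print(conv_output_size(7, 3))              # 5: the 7x7 input and 5x5 map in the figure above
print(conv_output_size(28, 3, p=1, s=2))   # 14: stride 2 halves the map</pre>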
  </div>
</article>

<!-- ========== 12 Filters as feature detectors ========== -->
<article id="cnn-filters" class="screen" data-screen-label="12 Filters">
  <div class="section-head">
    <div class="section-eyebrow">Part II · Section 12 · 4 min</div>
    <h2>Filters learn to detect features. We don't tell them what.</h2>
    <p class="section-lede">In a trained CNN, the early kernels reliably specialize in edges, blobs, and color contrasts — without ever being told.</p>
  </div>
  <div class="prose">
    <div class="fig fig-soft">
      <div class="fig-title"><strong>What each layer sees</strong><span>hierarchy of features</span></div>
      <div id="fig-hierarchy"></div>
    </div>
    <div class="caption">Layer 1 detects edges. Layer 2 combines edges into <b>textures</b> and <b>corners</b>. Layer 3 combines those into <b>parts</b> — a wheel, an eye. Layer 4 sees <b>objects</b>. Each layer composes the previous one.</div>

    <p>This hierarchy is <em>emergent</em> — the architecture rewards it without the loss function ever mentioning "edge" or "wheel." Cross-entropy on a final classification is enough; the rest falls out of stochastic gradient descent over millions of images.</p>

    <p class="aside">If you've ever heard that CNNs "see like brains," this is the half-truth in that claim: the visual cortex really does seem to organize into a similar hierarchy of features. But the algorithm that <em>finds</em> a CNN's weights is gradient descent, not biology.</p>
  </div>
</article>

<!-- ========== 13 Pooling ========== -->
<article id="cnn-pool" class="screen" data-screen-label="13 Pooling">
  <div class="section-head">
    <div class="section-eyebrow">Part II · Section 13 · 4 min · animated</div>
    <h2>Pooling: shrink the map, keep the signal.</h2>
    <p class="section-lede">After a few convs, the feature map is still huge. Pooling downsamples it — usually by 2× — so deeper layers can see a wider area through a smaller grid.</p>
  </div>
  <div class="prose">
    <div class="fig" id="fig-pool-mount"></div>
    <div class="caption"><b>Max pooling</b> over a 2×2 window: take the largest value in the window, throw the rest away. Output is half the size in each dimension, a quarter the area.</div>

    <h3>Why max — and not average?</h3>
    <p>Max pooling says "tell me whether this feature appeared <em>anywhere</em> in this neighbourhood, not where exactly." That fits how detection works: the existence of an edge in this region matters more than its precise pixel offset. Average pooling exists too. It is smoother, but it dilutes a strong response that shows up in only one cell of the window. Modern nets often skip pooling entirely in favor of <em>strided convolutions</em>: a conv with stride 2 downsamples the map and learns how to do so.</p>
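    <p>A minimal NumPy sketch of 2×2, stride-2 pooling, assuming the map's height and width are even (no padding). Reshaping gives each 2×2 window its own pair of axes, so max and average pooling differ only in the final reduction. The small feature map is made up to show one strong activation surviving max pooling while averaging dilutes it. <code>pool2x2</code> is a hypothetical name.</p>
    <pre class="code"><span class="code-tag">numpy · sketch</span>import numpy as np

def pool2x2(fmap, reduce=np.max):
    # Split an (H, W) map into non-overlapping 2x2 windows, then reduce each window.
    h, w = fmap.shape
    windows = fmap.reshape(h // 2, 2, w // 2, 2)
    return reduce(windows, axis=(1, 3))

fmap = np.array([
    [0, 0, 0, 0],
    [0, 8, 0, 1],
    [0, 0, 1, 1],
    [0, 0, 1, 1],
], dtype=float)

print(pool2x2(fmap, np.max))    # windows give 8, 1, 0, 1: the strong response survives intact
print(pool2x2(fmap, np.mean))   # windows give 2, 0.25, 0, 1: the same response diluted to 2</pre>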

    <h3>Receptive field — the why behind everything</h3>
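    <p>Each cell in a deep feature map is influenced by a <em>region</em> of input pixels: its <b>receptive field</b>. Every k×k conv widens that region by k − 1 cells of the layer below it. Every stride-2 pool doubles the spacing between cells, so each later layer widens the region twice as fast in pixel terms. Five 3×3 convs interleaved with two 2×2 pools (conv, conv, pool, conv, conv, pool, conv) give each final cell a 24×24 view of the original image, and deeper networks keep compounding. That's how a network "knows" it's looking at a face: deep cells see big areas.</p>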
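    <p>The bookkeeping fits in a short loop. The sketch below uses the standard recurrence for stacked layers. A layer with kernel size <em>k</em> and stride <em>s</em> grows the receptive field by (k − 1) × <em>jump</em>, where <em>jump</em> is the product of all earlier strides, and then multiplies the jump by <em>s</em>. The layer lists are illustrative, and <code>receptive_field</code> is a hypothetical helper.</p>
    <pre class="code"><span class="code-tag">python · sketch</span>def receptive_field(layers):
    # layers: list of (kernel_size, stride) pairs, input side first.
    rf, jump = 1, 1           # one input pixel; adjacent cells are 1 pixel apart
    for k, s in layers:
        rf += (k - 1) * jump  # the kernel spans k cells, each jump pixels apart
        jump *= s             # downsampling spreads later cells further apart
    return rf

conv, pool = (3, 1), (2, 2)

# The MNIST stack of the next section: conv, pool, conv, pool.
print(receptive_field([conv, pool, conv, pool]))                    # 10
# Five 3x3 convs and two 2x2 pools, interleaved.
print(receptive_field([conv, conv, pool, conv, conv, pool, conv]))  # 24</pre>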
  </div>
</article>

<!-- ========== 14 Full CNN stack ========== -->
<article id="cnn-stack" class="screen" data-screen-label="14 The CNN stack">
  <div class="section-head">
    <div class="section-eyebrow">Part II · Section 14 · 4 min</div>
    <h2>The full stack: conv → pool → conv → pool → flatten → dense.</h2>
    <p class="section-lede">A typical image classifier alternates conv and pool blocks until the spatial map is small (say 7×7), then flattens and feeds a dense head for the final classification.</p>
  </div>
  <div class="prose">
    <div class="fig fig-wide" id="fig-cnn-stack"></div>
    <div class="caption">A toy CNN for MNIST. Notice: the spatial dimensions <b>shrink</b> with each pool while the channel count <b>grows</b>. The network trades resolution for abstraction.</div>

    <h3>The shape calculus, demystified</h3>
    <ul>
      <li>Input: <code>28 × 28 × 1</code> (grayscale).</li>
      <li>Conv 32 filters, 3×3, padding=same: <code>28 × 28 × 32</code>.</li>
      <li>MaxPool 2×2: <code>14 × 14 × 32</code>.</li>
      <li>Conv 64 filters, 3×3, padding=same: <code>14 × 14 × 64</code>.</li>
      <li>MaxPool 2×2: <code>7 × 7 × 64</code>.</li>
      <li>Flatten: <code>3136</code>.</li>
      <li>Dense 128 → Dense 10 → softmax.</li>
    </ul>
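    <p>This single small network (about 422,000 parameters) hits ~99% on MNIST. Notice where those parameters live. Roughly 401,000 sit in the first dense layer, while the two conv layers together need fewer than 19,000. The convolutions extract the features on a tiny parameter budget, and the result typically beats a dense network on accuracy. That gap is the convolutional inductive bias paying off.</p>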
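    <p>The shape list and the parameter count can be recomputed by hand. This sketch walks the same stack. It assumes 'same' padding for the 3×3 convs, 2×2 stride-2 pools, and one bias per filter or unit. A conv layer then has k·k·C<sub>in</sub>·C<sub>out</sub> weights plus C<sub>out</sub> biases, and pooling, flattening and dropout add nothing. The helper names are hypothetical.</p>
    <pre class="code"><span class="code-tag">python · sketch</span>def conv_same(shape, filters, k=3):
    h, w, c_in = shape
    # 'same' padding, stride 1: spatial size unchanged.
    # k*k*c_in weights per filter, plus one bias per filter.
    return (h, w, filters), k * k * c_in * filters + filters

def max_pool2(shape):
    h, w, c = shape
    return (h // 2, w // 2, c), 0              # pooling has no parameters

def dense(n_in, units):
    return units, n_in * units + units         # weight matrix + one bias per unit

total = 0
shape = (28, 28, 1)
shape, p = conv_same(shape, 32)                # (28, 28, 32): 320 params
total += p
shape, p = max_pool2(shape)                    # (14, 14, 32)
total += p
shape, p = conv_same(shape, 64)                # (14, 14, 64): 18,496 params
total += p
shape, p = max_pool2(shape)                    # (7, 7, 64)
total += p

flat = shape[0] * shape[1] * shape[2]          # Flatten: 3136 values
units, p = dense(flat, 128)                    # 401,536 params
total += p
units, p = dense(units, 10)                    # 1,290 params
total += p

print(shape, flat, total)                      # (7, 7, 64) 3136 421642</pre>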
  </div>
</article>

<!-- ========== 15 Draw a digit ========== -->
<article id="cnn-draw" class="screen" data-screen-label="15 Draw a digit">
  <div class="section-head">
    <div class="section-eyebrow">Part II · Section 15 · 5 min · interactive</div>
    <h2>Draw a digit. Watch a CNN classify it.</h2>
    <p class="section-lede">Use the canvas. The simulated CNN extracts feature maps in real time and shows its prediction. Slow, deliberate strokes work better than fast scribbles.</p>
  </div>
  <div class="prose">
    <div class="fig fig-wide" id="fig-draw-mount"></div>
    <div class="caption">The "model" here is a small hand-tuned CNN running entirely in your browser, with the same shape as the Keras snippet in the next section. Feature maps update on every stroke.</div>

    <p>What you're seeing on the right are real activations from that small model, computed in JavaScript on your CPU. Notice how the early feature maps look like edge-filtered versions of your drawing, while deeper ones look more abstract.</p>
  </div>
</article>

<!-- ========== 16 CNN in Keras ========== -->
<article id="cnn-code" class="screen" data-screen-label="16 CNN in Keras">
  <div class="section-head">
    <div class="section-eyebrow">Part II · Section 16 · 3 min</div>
    <h2>A CNN in Keras.</h2>
    <p class="section-lede">The whole MNIST classifier we just discussed, in one short Keras script.</p>
  </div>
  <div class="prose">
    <pre class="code"><span class="code-tag">tensorflow / keras</span><span class="kw">import</span> tensorflow <span class="kw">as</span> tf
<span class="kw">from</span> tensorflow.keras <span class="kw">import</span> layers, models

<span class="com"># MNIST: 60,000 training and 10,000 test digits, 28×28 grayscale, labels 0–9</span>
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.<span class="fn">load_data</span>()
x_train = x_train[..., <span class="kw">None</span>].<span class="fn">astype</span>(<span class="str">'float32'</span>) / <span class="num">255.0</span>   <span class="com"># add channel axis, scale to [0, 1]</span>
x_test = x_test[..., <span class="kw">None</span>].<span class="fn">astype</span>(<span class="str">'float32'</span>) / <span class="num">255.0</span>

model = models.<span class="fn">Sequential</span>([
    layers.<span class="fn">Input</span>(shape=(<span class="num">28</span>, <span class="num">28</span>, <span class="num">1</span>)),

    layers.<span class="fn">Conv2D</span>(<span class="num">32</span>, kernel_size=<span class="num">3</span>, padding=<span class="str">'same'</span>, activation=<span class="str">'relu'</span>),
    layers.<span class="fn">MaxPooling2D</span>(pool_size=<span class="num">2</span>),

    layers.<span class="fn">Conv2D</span>(<span class="num">64</span>, kernel_size=<span class="num">3</span>, padding=<span class="str">'same'</span>, activation=<span class="str">'relu'</span>),
    layers.<span class="fn">MaxPooling2D</span>(pool_size=<span class="num">2</span>),

    layers.<span class="fn">Flatten</span>(),
    layers.<span class="fn">Dense</span>(<span class="num">128</span>, activation=<span class="str">'relu'</span>),
    layers.<span class="fn">Dropout</span>(<span class="num">0.3</span>),                    <span class="com"># regularization</span>
    layers.<span class="fn">Dense</span>(<span class="num">10</span>, activation=<span class="str">'softmax'</span>)
])

model.<span class="fn">compile</span>(optimizer=<span class="str">'adam'</span>,
              loss=<span class="str">'sparse_categorical_crossentropy'</span>,   <span class="com"># integer labels, no one-hot needed</span>
              metrics=[<span class="str">'accuracy'</span>])
model.<span class="fn">fit</span>(x_train, y_train, epochs=<span class="num">5</span>, batch_size=<span class="num">128</span>)
model.<span class="fn">evaluate</span>(x_test, y_test)                   <span class="com"># held-out accuracy</span></pre>

    <h3>Things to notice</h3>
    <ul>
      <li><b>Channels-last shape.</b> <code>(28, 28, 1)</code> — height, width, channels. RGB would be <code>(H, W, 3)</code>.</li>
      <li><b><code>Dropout</code></b> randomly zeros 30% of activations during training and scales up the rest. At inference it does nothing. This discourages the dense head from memorizing the training set.</li>
      <li><b><code>Flatten</code> appears only at the dense head.</b> Conv and pool layers consume and produce (height, width, channels) tensors for each example. Only the dense head needs a flat vector.</li>
    </ul>
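    <p>What <code>Dropout(0.3)</code> does during training can be sketched in NumPy. The sketch assumes the "inverted dropout" formulation that Keras documents: surviving activations are scaled by 1/(1 − rate) so their expected value is unchanged, and the layer passes inputs through untouched at inference. The <code>dropout</code> function here is a hypothetical illustration, not the Keras implementation.</p>
    <pre class="code"><span class="code-tag">numpy · sketch</span>import numpy as np

def dropout(x, rate, training, rng):
    # Hypothetical helper illustrating inverted dropout (not the Keras source).
    if not training:
        return x                                          # inference: pass through unchanged
    mask = rng.binomial(1, 1.0 - rate, size=x.shape)      # 1 = keep, 0 = drop
    return x * mask / (1.0 - rate)                        # rescale survivors to keep the mean

rng = np.random.default_rng(0)
x = np.ones(100_000)
y = dropout(x, 0.3, training=True, rng=rng)
print(np.mean(y == 0))      # close to 0.3: fraction of activations zeroed
print(y.mean())             # close to 1.0: expected activation preserved
print(np.array_equal(dropout(x, 0.3, training=False, rng=rng), x))   # True</pre>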

    <div class="callout">
      <div class="callout-title">When to reach for a pretrained model instead</div>
      <p>If you're classifying anything more complex than MNIST (real photos, medical images), you usually should <b>not</b> train from scratch. Start from a pretrained <code>EfficientNet</code> or <code>ResNet50</code> in <code>tf.keras.applications</code>, freeze the conv stack, and train only a new dense head on your data. This typically gets better accuracy from far less labeled data than training from scratch.</p>
    </div>
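    <p>A minimal sketch of that recipe with <code>tf.keras.applications.ResNet50</code>, under these assumptions:</p>
    <ul>
      <li>Inputs are 224×224 RGB images already passed through <code>tf.keras.applications.resnet50.preprocess_input</code>.</li>
      <li>The new head is a single softmax layer over a hypothetical <code>num_classes</code>.</li>
      <li>ImageNet weights are downloaded on first use.</li>
    </ul>
    <p><code>build_transfer_model</code> is an illustrative helper, not a Keras API. The names <code>train_images</code> and <code>train_labels</code> below are placeholders.</p>
    <pre class="code"><span class="code-tag">tensorflow / keras · sketch</span>import tensorflow as tf
from tensorflow.keras import layers

def build_transfer_model(num_classes, weights='imagenet'):
    # Pretrained conv stack without ImageNet's 1000-way classifier;
    # pooling='avg' turns the final feature map into a 2048-vector.
    base = tf.keras.applications.ResNet50(
        include_top=False, weights=weights,
        input_shape=(224, 224, 3), pooling='avg')
    base.trainable = False                      # freeze every conv layer

    inputs = tf.keras.Input(shape=(224, 224, 3))
    x = base(inputs, training=False)            # keep BatchNorm layers in inference mode
    outputs = layers.Dense(num_classes, activation='softmax')(x)
    model = tf.keras.Model(inputs, outputs)
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    return model

# model = build_transfer_model(num_classes=5)
# model.fit(train_images, train_labels, epochs=5)   # only the Dense head is trained</pre>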
  </div>
</article>
  `;
  if (ANCHOR) ANCHOR.insertAdjacentHTML('afterend', html);

  // ============================================================
  // Static figures
  // ============================================================

  // Flatten figure
  const fl = document.getElementById('fig-flatten');
  if (fl) {
    const cells = [];
    for (let r = 0; r < 6; r++) for (let c = 0; c < 6; c++) {
      const v = Math.sin(r*0.7) + Math.cos(c*0.6);
      const a = Math.max(0.1, Math.min(1, (v+1.5)/3));
      cells.push(`<rect x="${30+c*22}" y="${30+r*22}" width="20" height="20" fill="#1a1a1a" opacity="${a}"/>`);
    }
    const flatCells = [];
    for (let i = 0; i < 36; i++) {
      const r = Math.floor(i/6), c = i % 6;
      const v = Math.sin(r*0.7) + Math.cos(c*0.6);
      const a = Math.max(0.1, Math.min(1, (v+1.5)/3));
      flatCells.push(`<rect x="${250+i*12}" y="${100}" width="11" height="20" fill="#1a1a1a" opacity="${a}"/>`);
    }
    fl.innerHTML = `
      <svg viewBox="0 0 760 220" width="100%">
        ${cells.join('')}
        <text x="84" y="20" text-anchor="middle" font-family="JetBrains Mono" font-size="10" fill="#8a877f" letter-spacing="1">2D IMAGE · 6×6</text>

        <path d="M 165 76 Q 200 100 240 110" stroke="#c84e1d" stroke-width="1.4" fill="none" marker-end="url(#arr)"/>
        <text x="200" y="80" font-family="JetBrains Mono" font-size="10" fill="#c84e1d">flatten()</text>

        ${flatCells.join('')}
        <text x="466" y="92" text-anchor="middle" font-family="JetBrains Mono" font-size="10" fill="#8a877f" letter-spacing="1">1D VECTOR · 36</text>
        <text x="466" y="160" text-anchor="middle" font-family="Fraunces" font-style="italic" font-size="14" fill="#c84e1d">spatial relationships erased</text>

        <defs>
          <marker id="arr" markerWidth="8" markerHeight="8" refX="7" refY="4" orient="auto">
            <polygon points="0 0, 8 4, 0 8" fill="#c84e1d"/>
          </marker>
        </defs>
      </svg>
    `;
  }

  // Hierarchy figure
  const fh = document.getElementById('fig-hierarchy');
  if (fh) {
    const colW = 170;
    const lvls = [
      { x: 30,  title: 'LAYER 1 — edges',    desc: 'oriented bars',    pat: 'edges' },
      { x: 220, title: 'LAYER 2 — textures', desc: 'corners & curves', pat: 'tex' },
      { x: 410, title: 'LAYER 3 — parts',    desc: 'wheels, eyes, ears', pat: 'parts' },
      { x: 600, title: 'LAYER 4 — objects',  desc: 'faces, cars',      pat: 'objs' },
    ];
    function patch(x0, y0, kind, key) {
      let inner = '';
      const sz = 36;
      if (kind === 'edges') {
        // 9 small filters with edges
        for (let r = 0; r < 3; r++) for (let c = 0; c < 3; c++) {
          const a = (r*3+c) * 22;
          const cx = x0 + 8 + c*(sz+4) + sz/2;
          const cy = y0 + 8 + r*(sz+4) + sz/2;
          inner += `<rect x="${cx-sz/2}" y="${cy-sz/2}" width="${sz}" height="${sz}" fill="#fbfaf6" stroke="#c9c2ad"/>`;
          inner += `<line x1="${cx-12}" y1="${cy-12}" x2="${cx+12}" y2="${cy+12}" stroke="#1a1a1a" stroke-width="3" transform="rotate(${a} ${cx} ${cy})"/>`;
        }
      } else if (kind === 'tex') {
        for (let r = 0; r < 3; r++) for (let c = 0; c < 3; c++) {
          const cx = x0 + 8 + c*(sz+4) + sz/2;
          const cy = y0 + 8 + r*(sz+4) + sz/2;
          inner += `<rect x="${cx-sz/2}" y="${cy-sz/2}" width="${sz}" height="${sz}" fill="#fbfaf6" stroke="#c9c2ad"/>`;
          // little arc / corner
          const seed = r*3+c;
          if (seed % 3 === 0) inner += `<path d="M ${cx-12} ${cy+8} Q ${cx} ${cy-12} ${cx+12} ${cy+8}" stroke="#1a1a1a" stroke-width="2.4" fill="none"/>`;
          else if (seed % 3 === 1) inner += `<path d="M ${cx-12} ${cy-12} L ${cx-12} ${cy+12} M ${cx-12} ${cy} L ${cx+12} ${cy}" stroke="#1a1a1a" stroke-width="2"/>`;
          else inner += `<circle cx="${cx}" cy="${cy}" r="9" fill="none" stroke="#1a1a1a" stroke-width="2"/>`;
        }
      } else if (kind === 'parts') {
        // four bigger thumbnails
        const items = [
          'wheel','eye','ear','headlight'
        ];
        for (let r = 0; r < 2; r++) for (let c = 0; c < 2; c++) {
          const cx = x0 + 18 + c*(sz*1.6+10) + sz*0.8;
          const cy = y0 + 14 + r*(sz*1.6+10) + sz*0.8;
          inner += `<rect x="${cx-sz*0.8}" y="${cy-sz*0.8}" width="${sz*1.6}" height="${sz*1.6}" fill="#fbfaf6" stroke="#c9c2ad"/>`;
          const k = r*2+c;
          if (k === 0) {
            inner += `<circle cx="${cx}" cy="${cy}" r="20" fill="none" stroke="#1a1a1a" stroke-width="2.5"/>`;
            inner += `<circle cx="${cx}" cy="${cy}" r="6" fill="#1a1a1a"/>`;
          } else if (k === 1) {
            inner += `<ellipse cx="${cx}" cy="${cy}" rx="22" ry="10" fill="none" stroke="#1a1a1a" stroke-width="2.5"/>`;
            inner += `<circle cx="${cx}" cy="${cy}" r="6" fill="#1a1a1a"/>`;
          } else if (k === 2) {
            inner += `<path d="M ${cx-14} ${cy-14} Q ${cx+12} ${cy-8} ${cx+8} ${cy+14} Q ${cx-4} ${cy+8} ${cx-14} ${cy-14}" fill="none" stroke="#1a1a1a" stroke-width="2.4"/>`;
          } else {
            inner += `<rect x="${cx-18}" y="${cy-12}" width="36" height="24" rx="6" fill="none" stroke="#1a1a1a" stroke-width="2.4"/>`;
            inner += `<line x1="${cx-10}" y1="${cy-2}" x2="${cx+10}" y2="${cy-2}" stroke="#1a1a1a" stroke-width="2"/>`;
          }
        }
      } else {
        // object thumbnails
        for (let r = 0; r < 2; r++) for (let c = 0; c < 2; c++) {
          const cx = x0 + 18 + c*(sz*1.6+10) + sz*0.8;
          const cy = y0 + 14 + r*(sz*1.6+10) + sz*0.8;
          inner += `<rect x="${cx-sz*0.8}" y="${cy-sz*0.8}" width="${sz*1.6}" height="${sz*1.6}" fill="#fbfaf6" stroke="#c9c2ad"/>`;
          const k = r*2+c;
          if (k === 0) { // face
            inner += `<circle cx="${cx}" cy="${cy}" r="20" fill="none" stroke="#1a1a1a" stroke-width="2.4"/>`;
            inner += `<circle cx="${cx-7}" cy="${cy-4}" r="2.4" fill="#1a1a1a"/>`;
            inner += `<circle cx="${cx+7}" cy="${cy-4}" r="2.4" fill="#1a1a1a"/>`;
            inner += `<path d="M ${cx-6} ${cy+8} Q ${cx} ${cy+12} ${cx+6} ${cy+8}" stroke="#1a1a1a" stroke-width="2" fill="none"/>`;
          } else if (k === 1) { // car
            inner += `<rect x="${cx-22}" y="${cy-6}" width="44" height="14" fill="none" stroke="#1a1a1a" stroke-width="2.2"/>`;
            inner += `<path d="M ${cx-14} ${cy-6} L ${cx-8} ${cy-16} L ${cx+8} ${cy-16} L ${cx+14} ${cy-6}" fill="none" stroke="#1a1a1a" stroke-width="2.2"/>`;
            inner += `<circle cx="${cx-12}" cy="${cy+10}" r="4" fill="#1a1a1a"/>`;
            inner += `<circle cx="${cx+12}" cy="${cy+10}" r="4" fill="#1a1a1a"/>`;
          } else if (k === 2) { // cat-ish
            inner += `<polygon points="${cx-14},${cy-14} ${cx-6},${cy-2} ${cx-22},${cy-2}" fill="#1a1a1a"/>`;
            inner += `<polygon points="${cx+14},${cy-14} ${cx+22},${cy-2} ${cx+6},${cy-2}" fill="#1a1a1a"/>`;
            inner += `<circle cx="${cx}" cy="${cy+4}" r="14" fill="none" stroke="#1a1a1a" stroke-width="2.2"/>`;
          } else { // dog-ish
            inner += `<ellipse cx="${cx}" cy="${cy+2}" rx="20" ry="12" fill="none" stroke="#1a1a1a" stroke-width="2.2"/>`;
            inner += `<ellipse cx="${cx-12}" cy="${cy-10}" rx="6" ry="9" fill="#1a1a1a"/>`;
            inner += `<ellipse cx="${cx+12}" cy="${cy-10}" rx="6" ry="9" fill="#1a1a1a"/>`;
          }
        }
      }
      return inner;
    }
    let svg = '';
    lvls.forEach(l => {
      svg += `<rect x="${l.x}" y="14" width="160" height="160" fill="#fbfaf6" stroke="#c9c2ad" stroke-width="0.8"/>`;
      svg += patch(l.x, 14, l.pat, l.title);
      svg += `<text x="${l.x+80}" y="190" text-anchor="middle" font-family="Fraunces" font-size="13" font-weight="600">${l.title}</text>`;
      svg += `<text x="${l.x+80}" y="206" text-anchor="middle" font-family="JetBrains Mono" font-size="10" fill="#8a877f">${l.desc}</text>`;
    });
    // arrows between
    for (let i = 0; i < 3; i++) {
      const x1 = lvls[i].x + 162;
      const x2 = lvls[i+1].x - 2;
      svg += `<path d="M ${x1} 94 L ${x2} 94" stroke="#8a877f" stroke-width="1.2" stroke-dasharray="3 3" marker-end="url(#hArr)"/>`;
    }
    fh.innerHTML = `
      <svg viewBox="0 0 780 220" width="100%">
        <defs><marker id="hArr" markerWidth="8" markerHeight="8" refX="7" refY="4" orient="auto"><polygon points="0 0, 8 4, 0 8" fill="#8a877f"/></marker></defs>
        ${svg}
      </svg>
    `;
  }

  // CNN stack figure
  const fcs = document.getElementById('fig-cnn-stack');
  if (fcs) {
    const stages = [
      { x: 30, w: 84, h: 84, depth: 1, label: 'Input', sub: '28×28×1' },
      { x: 150, w: 84, h: 84, depth: 12, label: 'Conv32', sub: '28×28×32' },
      { x: 280, w: 60, h: 60, depth: 12, label: 'Pool', sub: '14×14×32' },
      { x: 380, w: 60, h: 60, depth: 22, label: 'Conv64', sub: '14×14×64' },
      { x: 490, w: 36, h: 36, depth: 22, label: 'Pool', sub: '7×7×64' },
      { x: 580, w: 14, h: 110, depth: 1, label: 'Flatten', sub: '3136' },
      { x: 660, w: 14, h: 80, depth: 1, label: 'Dense128', sub: '128' },
      { x: 740, w: 14, h: 30, depth: 1, label: 'Output', sub: '10' },
    ];
    let svg = '';
    stages.forEach((s, i) => {
      const y = 60;
      // depth illusion: draw stacked rects, each offset by 1.4px from the one in front
      for (let d = s.depth-1; d >= 0; d--) {
        const off = d * 1.4;
        const opacity = d === 0 ? 1 : 0.18;
        const fill = i === 0 ? '#dde7f7' : i === stages.length-1 ? '#d6e8de' : (s.label.startsWith('Pool') ? '#f1e4c2' : '#f6e4d8');
        const stroke = i === 0 ? '#1f6feb' : i === stages.length-1 ? '#1a7a4c' : (s.label.startsWith('Pool') ? '#b8860b' : '#c84e1d');
        svg += `<rect x="${s.x + off}" y="${y - off + (90 - s.h)/2}" width="${s.w}" height="${s.h}" fill="${fill}" stroke="${stroke}" stroke-width="${d===0?1.4:0.6}" opacity="${opacity}"/>`;
      }
      svg += `<text x="${s.x + s.w/2}" y="${y + 90 + 24}" text-anchor="middle" font-family="Fraunces" font-size="13" font-weight="600">${s.label}</text>`;
      svg += `<text x="${s.x + s.w/2}" y="${y + 90 + 40}" text-anchor="middle" font-family="JetBrains Mono" font-size="10" fill="#8a877f">${s.sub}</text>`;
      if (i < stages.length-1) {
        svg += `<path d="M ${s.x + s.w + 6} ${y+45} L ${stages[i+1].x - 4} ${y+45}" stroke="#8a877f" stroke-width="1" marker-end="url(#sArr)"/>`;
      }
    });
    fcs.innerHTML = `
      <svg viewBox="0 0 800 220" width="100%">
        <defs><marker id="sArr" markerWidth="8" markerHeight="8" refX="7" refY="4" orient="auto"><polygon points="0 0, 8 4, 0 8" fill="#8a877f"/></marker></defs>
        ${svg}
      </svg>
    `;
  }

  // ============================================================
  // React-mounted interactive figures
  // ============================================================
  window.__mountCnnFigures = function() {
    if (window.__cnnMounted) return;
    if (!window.React || !window.ReactDOM) return;
    window.__cnnMounted = true;

    const { useState, useEffect, useRef, useMemo } = React;

    // ---------- Convolution animation ----------
    function ConvFig() {
      const N = 7, M = 5; // input N, output M (no padding, kernel 3, stride 1)
      const kernel = [
        [-1, 0, 1],
        [-2, 0, 2],
        [-1, 0, 1],
      ];
      // simple input image: a vertical step edge (0.15 in columns 0-3, 0.85 in columns 4-6)
      const input = useMemo(() => {
        const arr = [];
        for (let r = 0; r < N; r++) {
          const row = [];
          for (let c = 0; c < N; c++) {
            row.push(c < N/2 ? 0.15 : 0.85);
          }
          arr.push(row);
        }
        return arr;
      }, []);

      const [pos, setPos] = useState(0); // 0..M*M-1
      const [playing, setPlaying] = useState(true);
      useEffect(() => {
        if (!playing) return;
        const id = setInterval(() => setPos(p => (p+1) % (M*M)), 700);
        return () => clearInterval(id);
      }, [playing]);

      const r = Math.floor(pos / M), c = pos % M;
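      // compute the whole feature map once (memoized); like DL frameworks, this is
      // cross-correlation: the kernel is applied without flipping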
      const output = useMemo(() => {
        const out = [];
        for (let i = 0; i < M; i++) {
          const row = [];
          for (let j = 0; j < M; j++) {
            let s = 0;
            for (let dr = 0; dr < 3; dr++) for (let dc = 0; dc < 3; dc++) {
              s += input[i+dr][j+dc] * kernel[dr][dc];
            }
            row.push(s);
          }
          out.push(row);
        }
        return out;
      }, [input]);

      const cell = 36;
      const inX = 30, inY = 30;
      const kX = inX + N*cell + 50, kY = 60;
      const outX = kX + 3*cell + 70, outY = 50;

      // input grid
      const inputCells = [];
      for (let i = 0; i < N; i++) for (let j = 0; j < N; j++) {
        const v = input[i][j];
        const inWindow = i >= r && i < r+3 && j >= c && j < c+3;
        inputCells.push(
          <g key={'i'+i+'-'+j}>
            <rect x={inX + j*cell} y={inY + i*cell} width={cell} height={cell}
              fill={`rgb(${Math.round(255*(1-v))},${Math.round(255*(1-v))},${Math.round(255*(1-v))})`}
              stroke={inWindow ? '#c84e1d' : '#c9c2ad'} strokeWidth={inWindow ? 1.6 : 0.5}/>
          </g>
        );
      }
      // window highlight
      const winRect = (
        <rect x={inX + c*cell - 1} y={inY + r*cell - 1} width={3*cell + 2} height={3*cell + 2}
          fill="none" stroke="#c84e1d" strokeWidth={2.2}/>
      );

      // kernel display
      const kernelCells = [];
      for (let i = 0; i < 3; i++) for (let j = 0; j < 3; j++) {
        const v = kernel[i][j];
        const fill = v > 0 ? `rgba(31,111,235,${0.25 + 0.25*Math.abs(v)})` : v < 0 ? `rgba(200,78,29,${0.25 + 0.25*Math.abs(v)})` : '#fbfaf6';
        kernelCells.push(
          <g key={'k'+i+'-'+j}>
            <rect x={kX + j*cell} y={kY + i*cell} width={cell} height={cell} fill={fill} stroke="#1a1a1a" strokeWidth={1}/>
            <text x={kX + j*cell + cell/2} y={kY + i*cell + cell/2 + 4} textAnchor="middle"
              fontFamily="JetBrains Mono" fontSize="13" fill="#1a1a1a">{v}</text>
          </g>
        );
      }

      // output grid
      const outputCells = [];
      const visited = pos;
      for (let i = 0; i < M; i++) for (let j = 0; j < M; j++) {
        const idx = i*M + j;
        const filled = idx <= visited;
        const v = output[i][j];
        // map v from roughly [-3,3] to color
        const t = Math.max(-1, Math.min(1, v / 3));
        const fill = filled ? (t > 0 ? `rgba(31,111,235,${Math.abs(t)*0.7+0.15})` : `rgba(200,78,29,${Math.abs(t)*0.7+0.15})`) : '#fbfaf6';
        const isCurrent = i === r && j === c;
        outputCells.push(
          <g key={'o'+i+'-'+j}>
            <rect x={outX + j*cell} y={outY + i*cell} width={cell} height={cell}
              fill={fill} stroke={isCurrent ? '#c84e1d' : '#c9c2ad'} strokeWidth={isCurrent ? 1.8 : 0.5}/>
            {filled && <text x={outX + j*cell + cell/2} y={outY + i*cell + cell/2 + 4} textAnchor="middle"
              fontFamily="JetBrains Mono" fontSize="11" fill="#1a1a1a">{v.toFixed(1)}</text>}
          </g>
        );
      }

      // arrows
      const arrows = (
        <g>
          <path d={`M ${inX + (c+1.5)*cell} ${inY + (r+1.5)*cell} Q ${inX + N*cell + 30} ${kY + 1.5*cell} ${kX - 4} ${kY + 1.5*cell}`}
            stroke="#c84e1d" strokeWidth={1.4} fill="none" strokeDasharray="3 3"/>
          <path d={`M ${kX + 3*cell + 4} ${kY + 1.5*cell} L ${outX + c*cell + cell/2} ${outY + r*cell + cell/2}`}
            stroke="#c84e1d" strokeWidth={1.4} fill="none" strokeDasharray="3 3" markerEnd="url(#convArr)"/>
        </g>
      );

      return (
        <div>
          <div className="fig-title"><strong>Convolution, step by step</strong><span>kernel = vertical edge detector</span></div>
          <svg viewBox={`0 0 ${outX + M*cell + 30} 320`} width="100%">
            <defs>
              <marker id="convArr" markerWidth="8" markerHeight="8" refX="7" refY="4" orient="auto">
                <polygon points="0 0, 8 4, 0 8" fill="#c84e1d"/>
              </marker>
            </defs>
            <text x={inX} y={inY-10} fontFamily="JetBrains Mono" fontSize="10" fill="#8a877f" letterSpacing="1">INPUT 7×7</text>
            <text x={kX} y={kY-10} fontFamily="JetBrains Mono" fontSize="10" fill="#8a877f" letterSpacing="1">KERNEL 3×3</text>
            <text x={outX} y={outY-10} fontFamily="JetBrains Mono" fontSize="10" fill="#8a877f" letterSpacing="1">FEATURE MAP 5×5</text>
            {inputCells}{winRect}
            {kernelCells}
            {outputCells}
            {arrows}
          </svg>
          <div className="fig-controls">
            <button className="btn-ghost btn-sm" onClick={() => setPlaying(p => !p)}>{playing ? '⏸ Pause' : '▶ Play'}</button>
            <button className="btn-ghost btn-sm" onClick={() => setPos(0)}>↺ Reset</button>
            <span className="ctrl-label" style={{marginLeft:12}}>position</span>
            <input type="range" min="0" max={M*M-1} value={pos} onChange={e => { setPlaying(false); setPos(+e.target.value); }} />
            <span className="ctrl-value">{pos+1}/{M*M}</span>
          </div>
        </div>
      );
    }
    const cm = document.getElementById('fig-conv-mount');
    if (cm) ReactDOM.createRoot(cm).render(<ConvFig/>);

    // ---------- Pooling animation ----------
    function PoolFig() {
      const N = 6;
      const grid = useMemo(() => {
        const arr = [];
        for (let r = 0; r < N; r++) {
          const row = [];
          for (let c = 0; c < N; c++) row.push(Math.round(Math.random()*9));
          arr.push(row);
        }
        return arr;
      }, []);
      const [pos, setPos] = useState(0);
      const [playing, setPlaying] = useState(true);
      useEffect(() => {
        if (!playing) return;
        const id = setInterval(() => setPos(p => (p+1) % 9), 800);
        return () => clearInterval(id);
      }, [playing]);

      const M = N/2;
      const r = Math.floor(pos / M), c = pos % M;
      // output
      const out = [];
      for (let i = 0; i < M; i++) {
        const row = [];
        for (let j = 0; j < M; j++) {
          row.push(Math.max(grid[i*2][j*2], grid[i*2][j*2+1], grid[i*2+1][j*2], grid[i*2+1][j*2+1]));
        }
        out.push(row);
      }

      const cell = 40;
      const ix = 30, iy = 30, ox = ix + N*cell + 80, oy = iy + cell;

      const inputCells = [];
      for (let i = 0; i < N; i++) for (let j = 0; j < N; j++) {
        const inWin = (Math.floor(i/2) === r) && (Math.floor(j/2) === c);
        inputCells.push(
          <g key={'pi'+i+'-'+j}>
            <rect x={ix + j*cell} y={iy + i*cell} width={cell} height={cell}
              fill={inWin ? '#f6e4d8' : '#fbfaf6'} stroke={inWin ? '#c84e1d' : '#c9c2ad'} strokeWidth={inWin ? 1.6 : 0.5}/>
            <text x={ix + j*cell + cell/2} y={iy + i*cell + cell/2 + 5} textAnchor="middle" fontFamily="JetBrains Mono" fontSize="13" fill="#1a1a1a">{grid[i][j]}</text>
          </g>
        );
      }
      const outCells = [];
      for (let i = 0; i < M; i++) for (let j = 0; j < M; j++) {
        const idx = i*M + j;
        const filled = idx <= pos;
        const isCurrent = i === r && j === c;
        outCells.push(
          <g key={'po'+i+'-'+j}>
            <rect x={ox + j*cell} y={oy + i*cell} width={cell} height={cell}
              fill={filled ? '#d6e8de' : '#fbfaf6'} stroke={isCurrent ? '#c84e1d' : '#c9c2ad'} strokeWidth={isCurrent ? 1.8 : 0.5}/>
            {filled && <text x={ox + j*cell + cell/2} y={oy + i*cell + cell/2 + 5} textAnchor="middle" fontFamily="JetBrains Mono" fontSize="13" fontWeight="600" fill="#1a7a4c">{out[i][j]}</text>}
          </g>
        );
      }
      // Dashed arrow from the centre of the current 2×2 window to its output cell
      const arr = (
        <path d={`M ${ix + c*2*cell + cell} ${iy + r*2*cell + cell} L ${ox + c*cell + cell/2} ${oy + r*cell + cell/2}`}
          stroke="#c84e1d" strokeWidth={1.4} strokeDasharray="3 3" markerEnd="url(#poolArr)" fill="none"/>
      );
      return (
        <div>
          <div className="fig-title"><strong>Max pooling, 2×2 stride 2</strong><span>4 cells in → 1 cell out</span></div>
          <svg viewBox={`0 0 ${ox + M*cell + 30} 290`} width="100%">
            <defs><marker id="poolArr" markerWidth="8" markerHeight="8" refX="7" refY="4" orient="auto"><polygon points="0 0, 8 4, 0 8" fill="#c84e1d"/></marker></defs>
            <text x={ix} y={iy-10} fontFamily="JetBrains Mono" fontSize="10" fill="#8a877f" letterSpacing="1">FEATURE MAP 6×6</text>
            <text x={ox} y={oy-10} fontFamily="JetBrains Mono" fontSize="10" fill="#8a877f" letterSpacing="1">POOLED 3×3</text>
            {inputCells}{outCells}{arr}
          </svg>
          <div className="fig-controls">
            <button className="btn-ghost btn-sm" onClick={() => setPlaying(p => !p)}>{playing ? '⏸ Pause' : '▶ Play'}</button>
            <button className="btn-ghost btn-sm" onClick={() => setPos(0)}>↺ Reset</button>
          </div>
        </div>
      );
    }
    const pm = document.getElementById('fig-pool-mount');
    if (pm) ReactDOM.createRoot(pm).render(<PoolFig/>);

    // ---------- Draw a digit ----------
    function DrawDigit() {
      const canvasRef = useRef(null);
      const [tick, setTick] = useState(0);
      const [pixels, setPixels] = useState(() => new Float32Array(28*28));
      const drawing = useRef(false);
      const last = useRef({ x: 0, y: 0 });

      function clear() {
        const c = canvasRef.current;
        if (!c) return;
        const ctx = c.getContext('2d');
        ctx.fillStyle = '#fbfaf6';
        ctx.fillRect(0, 0, c.width, c.height);
        setPixels(new Float32Array(28*28));
        setTick(t => t+1);
      }

      useEffect(() => { clear(); }, []);

      function getPos(e) {
        const c = canvasRef.current;
        const rect = c.getBoundingClientRect();
        const t = e.touches && e.touches[0];
        const x = (t ? t.clientX : e.clientX) - rect.left;
        const y = (t ? t.clientY : e.clientY) - rect.top;
        return { x: x * (c.width / rect.width), y: y * (c.height / rect.height) };
      }
      function down(e) {
        e.preventDefault();
        drawing.current = true;
        last.current = getPos(e);
        draw(e);
      }
      function up() {
        if (!drawing.current) return; // mouseleave also fires when no stroke is in progress
        drawing.current = false;
        updatePixels();
      }
      function draw(e) {
        if (!drawing.current) return;
        e.preventDefault();
        const c = canvasRef.current; const ctx = c.getContext('2d');
        const p = getPos(e);
        ctx.strokeStyle = '#1a1a1a';
        ctx.lineWidth = 22;
        ctx.lineCap = 'round';
        ctx.beginPath();
        ctx.moveTo(last.current.x, last.current.y);
        ctx.lineTo(p.x, p.y);
        ctx.stroke();
        last.current = p;
        // Recompute the 28×28 grid on every move event (not throttled)
        updatePixels();
      }
      function updatePixels() {
        const c = canvasRef.current; if (!c) return;
        const ctx = c.getContext('2d');
        // downsample 280x280 -> 28x28 grayscale
        const img = ctx.getImageData(0,0,c.width,c.height).data;
        const W = c.width, H = c.height;
        const sx = W/28, sy = H/28;
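        // For the 280×280 canvas, sx = sy = 280 / 28 = 10, so output cell (i, j)
        // averages the 10×10 block of canvas pixels with x in [10j, 10j + 10)
        // and y in [10i, 10i + 10).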
        const pix = new Float32Array(28*28);
        for (let i = 0; i < 28; i++) for (let j = 0; j < 28; j++) {
          let acc = 0, cnt = 0;
          const x0 = Math.floor(j*sx), x1 = Math.floor((j+1)*sx);
          const y0 = Math.floor(i*sy), y1 = Math.floor((i+1)*sy);
          for (let y = y0; y < y1; y++) for (let x = x0; x < x1; x++) {
            const idx = (y*W + x)*4;
            const v = (img[idx] + img[idx+1] + img[idx+2]) / 3;
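            // Ink relative to the #fbfaf6 background (mean channel value 249), so blank
            // paper maps to 0 instead of ≈0.02 per pixel; the #1a1a1a stroke maps to ≈0.9
            acc += Math.max(0, 1 - v / 249); cnt++;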
          }
          pix[i*28 + j] = acc / Math.max(1, cnt);
        }
        setPixels(pix);
        setTick(t => t+1);
      }

      // Compute simulated feature maps + prediction
      const { fmaps, prediction } = useMemo(() => {
        const k1 = [[1,0,-1],[2,0,-2],[1,0,-1]]; // vertical edge
        const k2 = [[1,2,1],[0,0,0],[-1,-2,-1]]; // horizontal edge
        const k3 = [[2,1,0],[1,0,-1],[0,-1,-2]]; // diagonal
        const k4 = [[-1,-1,-1],[-1,8,-1],[-1,-1,-1]]; // blob
        const filters = [k1, k2, k3, k4];
        function conv(input, kern, N) {
          const M = N - 2;
          const out = new Float32Array(M*M);
          for (let i = 0; i < M; i++) for (let j = 0; j < M; j++) {
            let s = 0;
            for (let dr = 0; dr < 3; dr++) for (let dc = 0; dc < 3; dc++) {
              s += input[(i+dr)*N + (j+dc)] * kern[dr][dc];
            }
            out[i*M + j] = Math.max(0, s); // ReLU
          }
          return out;
        }
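        // Worked example (illustrative numbers, not read from the canvas):
        // a 3×3 patch with ink on the left, [[1,1,0],[1,1,0],[1,1,0]], against the
        // vertical-edge kernel k1 gives
        //   1·1 + 1·0 + 0·(−1) + 1·2 + 1·0 + 0·(−2) + 1·1 + 1·0 + 0·(−1) = 4,
        // while the mirrored patch (ink on the right) gives −4, which the ReLU clamps
        // to 0, so each map responds to one edge polarity only. As in deep-learning
        // conv layers, the kernel is not flipped (cross-correlation).
        // Output size: a 3×3 window with no padding fits 28 − 3 + 1 = 26 times per
        // axis, hence the 26×26 maps drawn below.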
        const fmaps = filters.map(k => conv(pixels, k, 28));

        // Naive "classification" — heuristic on pixel mass distribution
        // (this is for show; runs in browser without weights)
        let probs = new Array(10).fill(0.0);
        // Centroid + density features
        let total = 0, mx = 0, my = 0;
        for (let i = 0; i < 28; i++) for (let j = 0; j < 28; j++) {
          const v = pixels[i*28+j];
          total += v; mx += v*j; my += v*i;
        }
        if (total > 5) {
          mx /= total; my /= total;
          // Check loop closure (rough): center darkness vs ring darkness
          let centerDark = 0, ringDark = 0;
          for (let i = 0; i < 28; i++) for (let j = 0; j < 28; j++) {
            const v = pixels[i*28+j];
            const dist = Math.hypot(i-14, j-14);
            if (dist < 5) centerDark += v;
            else if (dist < 11) ringDark += v;
          }
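          // Ink spread about the centre: vMass weights each pixel by its horizontal
          // distance |j − 14|, hMass by its vertical distance |i − 14|, so a tall, thin stroke has hMass ≫ vMass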
          let vMass = 0, hMass = 0;
          for (let i = 0; i < 28; i++) for (let j = 0; j < 28; j++) {
            vMass += pixels[i*28+j] * Math.abs(j-14);
            hMass += pixels[i*28+j] * Math.abs(i-14);
          }
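          // Worked example (idealised, assuming zero ink on the background): a
          // one-pixel-wide vertical stroke in column 14, rows 4..23, has total = 20,
          // vMass = Σ|14 − 14| = 0 and hMass = Σ|i − 14| = 55 + 45 = 100,
          // so hMass > 1.6 · vMass and total < 30: the "1" rule below fires.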
          // Heuristics for each digit
          probs[0] = ringDark > centerDark*1.5 ? 0.7 : 0.05;
          probs[1] = (hMass > vMass*1.6 && total < 30) ? 0.8 : 0.05;
          probs[2] = total > 20 && total < 60 ? 0.4 : 0.1;
          probs[3] = total > 25 ? 0.3 : 0.1;
          probs[4] = total > 20 ? 0.25 : 0.05;
          probs[5] = total > 30 ? 0.3 : 0.1;
          probs[6] = ringDark > 8 && centerDark > 2 ? 0.5 : 0.1;
          probs[7] = total < 30 ? 0.3 : 0.05;
          probs[8] = ringDark > 10 && centerDark > 4 ? 0.55 : 0.05;
          probs[9] = ringDark > 5 && my < 14 ? 0.4 : 0.1;
          // normalize
          const ssum = probs.reduce((a,b) => a+b, 0);
          probs = probs.map(p => p / ssum);
        } else {
          probs = new Array(10).fill(0.1);
        }
        return { fmaps, prediction: probs };
      }, [pixels]);

      // Render
      // Render a feature map as a 26×26 heatmap of SVG rects, opacity normalised to the map's own maximum
      function renderFmap(fm, x0, y0, label) {
        const M = 26;
        const cs = 4;
        const cells = [];
        let max = 0.001;
        for (let i = 0; i < fm.length; i++) max = Math.max(max, fm[i]);
        for (let i = 0; i < M; i++) for (let j = 0; j < M; j++) {
          const v = fm[i*M + j] / max;
          cells.push(
            <rect key={`${label}-${i}-${j}`} x={x0 + j*cs} y={y0 + i*cs} width={cs} height={cs}
              fill={`rgba(200,78,29,${v})`}/>
          );
        }
        return (
          <g>
            <rect x={x0-1} y={y0-1} width={M*cs+2} height={M*cs+2} fill="#fbfaf6" stroke="#c9c2ad" strokeWidth={0.6}/>
            {cells}
            <text x={x0 + M*cs/2} y={y0 + M*cs + 12} textAnchor="middle" fontFamily="JetBrains Mono" fontSize="9" fill="#8a877f">{label}</text>
          </g>
        );
      }

      // Rank classes by score; only the top class is highlighted
      const ranked = prediction.map((p,i) => ({p, i})).sort((a,b) => b.p - a.p);
      const top = ranked[0];

      return (
        <div>
          <div className="fig-title"><strong>Draw a digit · live CNN</strong><span>JS-only, runs locally</span></div>
          <div style={{display:'grid', gridTemplateColumns:'auto 1fr', gap:32, alignItems:'start'}}>
            <div>
              <canvas
                ref={canvasRef}
                width={280}
                height={280}
                style={{ border:'2px solid #1a1a1a', background:'#fbfaf6', cursor:'crosshair', display:'block', borderRadius:4, touchAction:'none' }}
                onMouseDown={down} onMouseMove={draw} onMouseUp={up} onMouseLeave={up}
                onTouchStart={down} onTouchMove={draw} onTouchEnd={up}
              />
              <div style={{marginTop:12, display:'flex', gap:8}}>
                <button className="btn-ghost btn-sm" onClick={clear}>Clear</button>
                <span style={{flex:1}}/>
                <span style={{fontFamily:'JetBrains Mono',fontSize:11,color:'#8a877f'}}>draw a digit 0–9</span>
              </div>
            </div>

            <div>
              <div style={{fontFamily:'JetBrains Mono', fontSize:10, color:'#8a877f', letterSpacing:1, marginBottom:8}}>FEATURE MAPS · 4 FIXED FILTERS</div>
              <svg viewBox="0 0 480 130" width="100%" style={{maxWidth:480}}>
                {renderFmap(fmaps[0], 6, 6, 'vertical edge')}
                {renderFmap(fmaps[1], 124, 6, 'horizontal edge')}
                {renderFmap(fmaps[2], 242, 6, 'diagonal')}
                {renderFmap(fmaps[3], 360, 6, 'blob')}
              </svg>

              <div style={{fontFamily:'JetBrains Mono', fontSize:10, color:'#8a877f', letterSpacing:1, margin:'18px 0 8px'}}>PREDICTION · CLASS PROBABILITIES</div>
              <div style={{display:'grid', gridTemplateColumns:'24px 1fr 48px', gap:'4px 10px', alignItems:'center'}}>
                {prediction.map((p, i) => (
                  <React.Fragment key={'p'+i}>
                    <div style={{fontFamily:'Fraunces', fontSize:14, fontWeight:600, color: i === top.i ? '#1a7a4c' : '#8a877f'}}>{i}</div>
                    <div style={{height:14, background:'#f4f1e8', position:'relative', borderRadius:2}}>
                      <div style={{position:'absolute', left:0, top:0, height:'100%', width:`${p*100}%`, background: i === top.i ? '#1a7a4c' : '#c84e1d', opacity: 0.6 + p*0.4, borderRadius:2, transition:'width .2s'}}/>
                    </div>
                    <div style={{fontFamily:'JetBrains Mono', fontSize:11, color: i === top.i ? '#1a7a4c' : '#8a877f', textAlign:'right'}}>{(p*100).toFixed(1)}%</div>
                  </React.Fragment>
                ))}
              </div>
              <div style={{marginTop:12, fontFamily:'Fraunces', fontStyle:'italic', fontSize:13, color:'#8a877f'}}>
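                The classifier here is a hand-written heuristic, not a trained network, and the four filters are fixed edge and blob detectors rather than learned ones. A small Keras CNN trained on MNIST typically reaches about 99% test accuracy. The point is to <em>see</em> the feature maps light up as you draw.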
              </div>
            </div>
          </div>
        </div>
      );
    }
    const dm = document.getElementById('fig-draw-mount');
    if (dm) ReactDOM.createRoot(dm).render(<DrawDigit/>);
  };
})();
