/* anim-langgraph.jsx — Animated StateGraph execution */

const LangGraphSim = () => {
  // A graph: START -> classify -> [tool_use | respond] -> END
  // We animate: state object updating, current node highlight, edge flow
  const STEPS = [
    { node: "START",    state: { messages: [{ role: "user", content: "What's the weather in Tokyo?" }], next: null } },
    { node: "classify", state: { messages: [{ role: "user", content: "What's the weather in Tokyo?" }], next: "tool_use", intent: "weather_query" } },
    { node: "tool_use", state: { messages: [{ role: "user", content: "What's the weather in Tokyo?" }, { role: "tool", content: "{ temp: 14°C, sky: clear }" }], next: "respond", intent: "weather_query" } },
    { node: "respond",  state: { messages: [{ role: "user", content: "What's the weather in Tokyo?" }, { role: "tool", content: "{ temp: 14°C, sky: clear }" }, { role: "assistant", content: "Tokyo is 14°C and clear." }], next: "END", intent: "weather_query" } },
    { node: "END",      state: { messages: [{ role: "user", content: "..." }, { role: "tool", content: "..." }, { role: "assistant", content: "Tokyo is 14°C and clear." }], next: null, intent: "weather_query" } },
  ];

  const { step, playing, toggle, next, prev, reset, containerRef } = useStepper({ steps: STEPS.length, intervalMs: 2200 });
  const cur = STEPS[step];

  const NODES = {
    START:    { x: 80,  y: 100, r: 22, label: "START", kind: "term" },
    classify: { x: 240, y: 100, r: 38, label: "classify", kind: "node" },
    tool_use: { x: 410, y: 50,  r: 38, label: "tool_use", kind: "node" },
    respond:  { x: 410, y: 150, r: 38, label: "respond",  kind: "node" },
    END:      { x: 580, y: 100, r: 22, label: "END",     kind: "term" },
  };
  const EDGES = [
    { from: "START",    to: "classify", kind: "solid" },
    { from: "classify", to: "tool_use", kind: "cond", label: "if tool_use" },
    { from: "classify", to: "respond",  kind: "cond", label: "else" },
    { from: "tool_use", to: "respond",  kind: "solid" },
    { from: "respond",  to: "END",      kind: "solid" },
  ];

  // Determine which edge is "flowing" between cur step and next step
  const activeEdge = (() => {
    if (step + 1 >= STEPS.length) return null;
    return { from: cur.node, to: STEPS[step+1].node };
  })();

  const isActive = (n) => cur.node === n;
  const isVisited = (n) => STEPS.slice(0, step+1).some(s => s.node === n);

  return (
    <div ref={containerRef}>
      <div style={{ display: "grid", gridTemplateColumns: "1.4fr 1fr", gap: 24 }}>
        {/* Graph */}
        <div>
          <svg viewBox="0 0 660 220" style={{ width: "100%", height: "auto" }}>
            <defs>
              <marker id="arr" viewBox="0 0 10 10" refX="8" refY="5" markerWidth="7" markerHeight="7" orient="auto">
                <path d="M0,0 L10,5 L0,10 z" fill="#6b6b66" />
              </marker>
              <marker id="arr-active" viewBox="0 0 10 10" refX="8" refY="5" markerWidth="7" markerHeight="7" orient="auto">
                <path d="M0,0 L10,5 L0,10 z" fill="#2d5fb8" />
              </marker>
            </defs>

            {/* Edges */}
            {EDGES.map((e, i) => {
              const a = NODES[e.from], b = NODES[e.to];
              const active = activeEdge && activeEdge.from === e.from && activeEdge.to === e.to;
              // Trim to node radius
              const dx = b.x - a.x, dy = b.y - a.y;
              const len = Math.hypot(dx, dy);
              const ux = dx/len, uy = dy/len;
              const x1 = a.x + ux*a.r, y1 = a.y + uy*a.r;
              const x2 = b.x - ux*b.r, y2 = b.y - uy*b.r;
              const mid = { x: (x1+x2)/2, y: (y1+y2)/2 };
              return (
                <g key={i}>
                  <line
                    x1={x1} y1={y1} x2={x2} y2={y2}
                    stroke={active ? "#2d5fb8" : "#c7c7c0"}
                    strokeWidth={active ? 2.2 : 1.2}
                    strokeDasharray={e.kind === "cond" ? "4 3" : "0"}
                    markerEnd={active ? "url(#arr-active)" : "url(#arr)"}
                    style={{ transition: "stroke 0.3s, stroke-width 0.3s" }}
                  />
                  {e.label && (
                    <text x={mid.x + (e.kind === "cond" && a.x === b.x ? 0 : 0)} y={mid.y - 6}
                      textAnchor="middle"
                      fontSize="10" fontFamily="JetBrains Mono, monospace"
                      fill={active ? "#2d5fb8" : "#a3a39d"}
                      style={{ transition: "fill 0.3s" }}>
                      {e.label}
                    </text>
                  )}
                  {/* Token particle traveling along the edge */}
                  {active && (
                    <circle r="4" fill="#c7521c">
                      <animateMotion dur="1.6s" repeatCount="indefinite"
                        path={`M ${x1} ${y1} L ${x2} ${y2}`} />
                    </circle>
                  )}
                </g>
              );
            })}

            {/* Nodes */}
            {Object.entries(NODES).map(([key, n]) => {
              const active = isActive(key);
              const visited = isVisited(key);
              if (n.kind === "term") {
                return (
                  <g key={key}>
                    <circle cx={n.x} cy={n.y} r={n.r}
                      fill={active ? "#1a1a1a" : visited ? "#fbfbfa" : "#fbfbfa"}
                      stroke={active ? "#1a1a1a" : visited ? "#6b6b66" : "#c7c7c0"}
                      strokeWidth={active ? 2 : 1}
                      style={{ transition: "all 0.3s" }} />
                    <text x={n.x} y={n.y+4} textAnchor="middle"
                      fontSize="10" fontFamily="JetBrains Mono, monospace"
                      fontWeight="600"
                      fill={active ? "#fbfbfa" : visited ? "#1a1a1a" : "#a3a39d"}>
                      {n.label}
                    </text>
                  </g>
                );
              }
              return (
                <g key={key}>
                  {active && (
                    <circle cx={n.x} cy={n.y} r={n.r + 8}
                      fill="none" stroke="#2d5fb8" strokeWidth="1.5" opacity="0.4">
                      <animate attributeName="r" from={n.r+4} to={n.r+12} dur="1.4s" repeatCount="indefinite" />
                      <animate attributeName="opacity" from="0.5" to="0" dur="1.4s" repeatCount="indefinite" />
                    </circle>
                  )}
                  <circle cx={n.x} cy={n.y} r={n.r}
                    fill={active ? "#e8f0fe" : visited ? "#fbfbfa" : "#fbfbfa"}
                    stroke={active ? "#2d5fb8" : visited ? "#6b6b66" : "#c7c7c0"}
                    strokeWidth={active ? 2 : 1}
                    style={{ transition: "all 0.3s" }} />
                  <text x={n.x} y={n.y+4} textAnchor="middle"
                    fontSize="13" fontFamily="Inter Tight, sans-serif"
                    fontWeight={active ? "600" : "500"}
                    fill={active ? "#1d3f7d" : visited ? "#1a1a1a" : "#a3a39d"}>
                    {n.label}
                  </text>
                </g>
              );
            })}
          </svg>
        </div>

        {/* State inspector */}
        <div style={{
          background: "#1a1d23", borderRadius: 6, padding: 16,
          fontFamily: "JetBrains Mono, monospace", fontSize: 12,
          color: "#d4d4d4", overflow: "auto", maxHeight: 280,
        }}>
          <div style={{ color: "#6a9955", marginBottom: 8 }}>// state at node "{cur.node}"</div>
          <StateView state={cur.state} />
        </div>
      </div>

      <div style={{ marginTop: 16, display: "flex", gap: 6, flexWrap: "wrap" }}>
        {STEPS.map((s, i) => (
          <div key={i} style={{
            fontSize: 11, fontFamily: "JetBrains Mono, monospace",
            padding: "4px 10px", borderRadius: 3,
            background: i === step ? "#1a1a1a" : i < step ? "#e7e7e2" : "transparent",
            color: i === step ? "#fbfbfa" : i < step ? "#3d3d3a" : "#a3a39d",
            border: `1px solid ${i <= step ? "transparent" : "#e7e7e2"}`,
            transition: "all 0.3s",
          }}>
            {String(i).padStart(2, "0")} {s.node}
          </div>
        ))}
      </div>
    </div>
  );
};

const StateView = ({ state }) => {
  const lines = [];
  lines.push(<div key="o">{"{"}</div>);
  Object.entries(state).forEach(([k, v], i) => {
    if (Array.isArray(v)) {
      lines.push(
        <div key={k} style={{ paddingLeft: 14 }}>
          <span style={{ color: "#9cdcfe" }}>{k}</span>
          <span>: [</span>
        </div>
      );
      v.forEach((item, j) => {
        lines.push(
          <div key={k+j} style={{ paddingLeft: 28, animation: "fadeIn 0.4s" }}>
            <span>{"{ "}</span>
            <span style={{ color: "#9cdcfe" }}>role</span>
            <span>: </span>
            <span style={{ color: "#ce9178" }}>"{item.role}"</span>
            <span>, </span>
            <span style={{ color: "#9cdcfe" }}>content</span>
            <span>: </span>
            <span style={{ color: "#ce9178" }}>"{item.content.length > 40 ? item.content.slice(0,40)+"…" : item.content}"</span>
            <span> {"}"}{j < v.length-1 ? "," : ""}</span>
          </div>
        );
      });
      lines.push(<div key={k+"]"} style={{ paddingLeft: 14 }}>],</div>);
    } else {
      lines.push(
        <div key={k} style={{ paddingLeft: 14 }}>
          <span style={{ color: "#9cdcfe" }}>{k}</span>
          <span>: </span>
          <span style={{ color: v === null ? "#569cd6" : "#ce9178" }}>{v === null ? "null" : `"${v}"`}</span>
          <span>,</span>
        </div>
      );
    }
  });
  lines.push(<div key="c">{"}"}</div>);
  return <>{lines}</>;
};

// Publish the component as a property of the global `window` object.
Object.assign(window, { LangGraphSim });
