#!/usr/bin/env bash
set -euo pipefail

# Gemma 4 MTP A/B benchmark on SGLang.
#
# Launches the server twice through SERVE_SCRIPT - once with ENABLE_MTP=0
# ("baseline") and once with ENABLE_MTP=1 ("mtp_topk1") - smoke-tests and
# benchmarks each configuration, then writes a side-by-side comparison to
# $RESULTS_DIR/summary.json. Every setting below may be overridden from the
# environment; an unset or empty value falls back to the default shown.

# Repository root: the parent of the directory that holds this script.
ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"

# Helper scripts and output locations.
: "${BENCHMARK:=$ROOT_DIR/vendor/rtx6kpro/benchmarks/inference-throughput/benchmark_sglang.py}"
: "${SERVE_SCRIPT:=$ROOT_DIR/scripts/serve-gemma4-sglang.sh}"
: "${SMOKE_SCRIPT:=$ROOT_DIR/scripts/smoke_test_gemma4_openai.py}"
: "${RESULTS_DIR:=$ROOT_DIR/results/gemma4-mtp-sglang}"
: "${LOG_DIR:=$RESULTS_DIR/logs}"

# Model and server settings.
: "${VARIANT:=26B-A4B}"
: "${MODEL:=google/gemma-4-26B-A4B-it}"
: "${HOST:=127.0.0.1}"
: "${PORT:=30000}"
: "${TP_SIZE:=1}"
: "${MEM_FRACTION_STATIC:=0.85}"
: "${CONTEXT_LENGTH:=262144}"
: "${MAX_RUNNING_REQUESTS:=48}"

# Benchmark sweep (comma-separated lists) and generation limits.
: "${CONCURRENCY:=1,4,8}"
: "${CONTEXTS:=0,8192,32768}"
: "${DURATION:=30}"
: "${MAX_TOKENS:=2048}"
: "${IGNORE_EOS:=1}"
: "${MIN_TOKENS:=$MAX_TOKENS}"   # defaults to whatever MAX_TOKENS resolved to
: "${MAX_TOTAL_TOKENS:=0}"

# Seconds to keep polling the server's /v1/models endpoint before giving up.
: "${WAIT_TIMEOUT:=2400}"

mkdir -p "$RESULTS_DIR" "$LOG_DIR"

# PID of the server launcher started by run_case ("" when none is running).
server_pid=""

#######################################
# Send a signal to the direct children of a PID, then to the PID itself.
# Children go first because SERVE_SCRIPT may run the server as a child instead
# of exec'ing it (NOTE(review): not visible from here - confirm); signalling
# only the wrapper would then orphan the real server. Best effort: pkill is
# optional and every failure is ignored.
# Arguments: $1 - signal name (e.g. TERM, KILL)
#            $2 - PID to signal
#######################################
_signal_server_tree() {
  local sig="$1" pid="$2"
  if command -v pkill >/dev/null 2>&1; then
    pkill "-${sig}" -P "$pid" 2>/dev/null || true
  fi
  kill "-${sig}" "$pid" 2>/dev/null || true
}

#######################################
# Stop the server recorded in server_pid, if it is still alive. Safe to call
# repeatedly; also installed as the EXIT trap so an aborted run does not leave
# a server holding the port.
# Sends SIGTERM, waits up to STOP_TIMEOUT seconds (default 60), then escalates
# to SIGKILL so the EXIT trap cannot hang on a server that ignores SIGTERM.
# Globals:   server_pid (read), STOP_TIMEOUT (read, optional)
# Returns:   always 0
#######################################
cleanup() {
  local pid="${server_pid:-}"
  if [[ -z "$pid" ]] || ! kill -0 "$pid" 2>/dev/null; then
    return 0
  fi

  _signal_server_tree TERM "$pid"

  # Poll in 0.1 s ticks so a prompt shutdown is not padded to a full second.
  local ticks=$(( ${STOP_TIMEOUT:-60} * 10 ))
  while (( ticks > 0 )) && kill -0 "$pid" 2>/dev/null; do
    sleep 0.1
    ticks=$((ticks - 1))
  done

  if kill -0 "$pid" 2>/dev/null; then
    echo "Server ${pid} did not stop within ${STOP_TIMEOUT:-60}s; sending SIGKILL" >&2
    _signal_server_tree KILL "$pid"
  fi
  # Reap the launcher; its exit status is irrelevant at this point.
  wait "$pid" 2>/dev/null || true
  return 0
}
trap cleanup EXIT

#######################################
# Poll the OpenAI-compatible /v1/models endpoint until the server answers.
# Fails early, printing the tail of the server log, if the launched server
# process has already exited; otherwise fails once WAIT_TIMEOUT seconds pass.
# Globals:   HOST, PORT, WAIT_TIMEOUT, server_pid, current_log (read)
# Returns:   0 when the endpoint responds, 1 on early exit or timeout
#######################################
wait_for_server() {
  local models_url="http://${HOST}:${PORT}/v1/models"
  local deadline=$((SECONDS + WAIT_TIMEOUT))
  # --max-time bounds each probe: without it, a connection that is accepted
  # but never answered would block here well past the deadline.
  until curl -fsS --max-time 10 "$models_url" >/dev/null 2>&1; do
    # Check for a dead server first: it is the more informative failure.
    if [[ -n "${server_pid:-}" ]] && ! kill -0 "$server_pid" 2>/dev/null; then
      echo "SGLang server exited before readiness. Last log lines:" >&2
      tail -n 120 "${current_log:-/dev/null}" >&2 || true
      return 1
    fi
    if (( SECONDS > deadline )); then
      echo "Timed out waiting for ${models_url}" >&2
      return 1
    fi
    sleep 5
  done
}

#######################################
# Fail fast if a helper script is missing, before spending minutes on a
# server launch that could never be benchmarked.
# Globals:   SERVE_SCRIPT, SMOKE_SCRIPT, BENCHMARK (read)
# Returns:   0 when all inputs exist, 1 otherwise (message on stderr)
#######################################
_require_case_inputs() {
  if [[ ! -x "$SERVE_SCRIPT" ]]; then
    echo "Serve script is missing or not executable: ${SERVE_SCRIPT}" >&2
    return 1
  fi
  local required
  for required in "$SMOKE_SCRIPT" "$BENCHMARK"; do
    if [[ ! -f "$required" ]]; then
      echo "Required file not found: ${required}" >&2
      return 1
    fi
  done
}

#######################################
# Wait until nothing answers the readiness probe on HOST:PORT, for example a
# server from the previous case that is still shutting down. Otherwise
# wait_for_server could succeed against that stale server, and the wrong
# configuration would be benchmarked.
# Globals:   HOST, PORT, PORT_FREE_TIMEOUT (read; optional, default 120 s)
# Arguments: $1 - case label, used in the error message
# Returns:   0 once the probe fails, 1 if it still succeeds at the deadline
#######################################
_wait_for_port_free() {
  local label="$1"
  local probe_url="http://${HOST}:${PORT}/v1/models"
  local deadline=$((SECONDS + ${PORT_FREE_TIMEOUT:-120}))
  while curl -fsS --max-time 5 "$probe_url" >/dev/null 2>&1; do
    if (( SECONDS > deadline )); then
      echo "Refusing to start ${label}: ${probe_url} is still answering (stale server?)" >&2
      return 1
    fi
    sleep 5
  done
}

#######################################
# Launch one server configuration, smoke-test it, benchmark it, and stop it.
# Globals:   current_log, server_pid (written); SERVE_SCRIPT, SMOKE_SCRIPT,
#            BENCHMARK, RESULTS_DIR, LOG_DIR, PYTHON (optional, default
#            "python") and the model/benchmark settings (read)
# Arguments: $1 - label used for log and result file names
#            $2 - value passed to SERVE_SCRIPT as ENABLE_MTP
# Outputs:   server log at $LOG_DIR/<label>.log; smoke output, live state and
#            benchmark JSON at $RESULTS_DIR/<label>.{smoke.txt,live.json,json}
# Returns:   non-zero if the preflight, readiness, smoke test or benchmark
#            fails; the EXIT trap then stops the server.
#######################################
run_case() {
  local label="$1"
  local enable_mtp="$2"
  # Deliberately global: wait_for_server tails this log if the server dies.
  current_log="$LOG_DIR/${label}.log"
  local base_url="http://${HOST}:${PORT}"

  echo
  echo "=== Starting ${label} (ENABLE_MTP=${enable_mtp}) ==="
  _require_case_inputs || return 1
  _wait_for_port_free "$label" || return 1

  # SERVE_SCRIPT gets HOST=0.0.0.0 (presumably its bind address); the local
  # probes and clients below connect through $HOST.
  ENABLE_MTP="$enable_mtp" VARIANT="$VARIANT" MODEL_PATH="$MODEL" SERVED_MODEL_NAME="$MODEL" \
    HOST="0.0.0.0" PORT="$PORT" TP_SIZE="$TP_SIZE" \
    MEM_FRACTION_STATIC="$MEM_FRACTION_STATIC" CONTEXT_LENGTH="$CONTEXT_LENGTH" \
    MAX_RUNNING_REQUESTS="$MAX_RUNNING_REQUESTS" \
    "$SERVE_SCRIPT" >"$current_log" 2>&1 &
  server_pid="$!"

  wait_for_server || return 1

  "${PYTHON:-python}" "$SMOKE_SCRIPT" \
    --base-url "${base_url}/v1" \
    --model "$MODEL" | tee "$RESULTS_DIR/${label}.smoke.txt" || return 1

  local -a bench_args=()
  if [[ "$IGNORE_EOS" == "1" ]]; then
    bench_args+=(--ignore-eos)
  fi
  if (( MIN_TOKENS > 0 )); then
    bench_args+=(--min-tokens "$MIN_TOKENS")
  fi

  # ${arr[@]+"${arr[@]}"} expands an empty array to nothing. A plain
  # "${arr[@]}" aborts with "unbound variable" under set -u on bash < 4.4.
  "${PYTHON:-python}" "$BENCHMARK" \
    --host "$HOST" \
    --port "$PORT" \
    --model "$MODEL" \
    --concurrency "$CONCURRENCY" \
    --contexts "$CONTEXTS" \
    --duration "$DURATION" \
    --max-tokens "$MAX_TOKENS" \
    ${bench_args[@]+"${bench_args[@]}"} \
    --max-total-tokens "$MAX_TOTAL_TOKENS" \
    --live-state "$RESULTS_DIR/${label}.live.json" \
    --output "$RESULTS_DIR/${label}.json" || return 1

  cleanup
  server_pid=""
  # Fixed pause before the next case (presumably to let GPU memory be freed).
  sleep 10
}

#######################################
# Compare the baseline and MTP benchmark result files cell by cell.
# Each input is read as JSON with an optional "metadata" object and a
# "results" list whose entries carry context_tokens, concurrency,
# aggregate_tps and ttft_avg. Optional fields per entry: total_tokens,
# wall_time, server_spec_accept_rate.
# Arguments: $1 - baseline results JSON
#            $2 - MTP results JSON
#            $3 - path the comparison summary JSON is written to
# Globals:   PYTHON (read; optional, defaults to "python")
# Outputs:   comparison table on stdout. Warnings on stderr for cells that
#            only one run produced. Summary JSON written to $3.
# Returns:   the Python interpreter's exit status
#######################################
summarize_results() {
  "${PYTHON:-python}" - "$1" "$2" "$3" <<'PY'
import json
import sys

baseline_path, mtp_path, out_path = sys.argv[1:4]
with open(baseline_path) as f:
    baseline = json.load(f)
with open(mtp_path) as f:
    mtp = json.load(f)


def cells(doc):
    """Index a result document's rows by (context_tokens, concurrency)."""
    return {
        (r["context_tokens"], r["concurrency"]): r
        for r in doc.get("results", [])
    }


def opt(row, key, default):
    """Return row[key], or default when the key is missing or null."""
    value = row.get(key)
    return default if value is None else value


b = cells(baseline)
m = cells(mtp)

# A cell that only one run produced cannot be compared. Report it instead
# of dropping it silently.
for name, missing in (("baseline", set(b) - set(m)), ("mtp", set(m) - set(b))):
    for ctx, conc in sorted(missing):
        print(
            f"warning: ctx={ctx} conc={conc} only in {name} results; skipped",
            file=sys.stderr,
        )

rows = []
for key in sorted(set(b) & set(m)):
    br = b[key]["aggregate_tps"]
    mr = m[key]["aggregate_tps"]
    speedup = (mr / br) if br > 0 else 0.0
    rows.append({
        "context_tokens": key[0],
        "concurrency": key[1],
        "baseline_tps": round(br, 3),
        "mtp_tps": round(mr, 3),
        "speedup": round(speedup, 3),
        "baseline_ttft": round(b[key]["ttft_avg"], 3),
        "mtp_ttft": round(m[key]["ttft_avg"], 3),
        # opt(): a null metric would otherwise crash round().
        "mtp_accept_rate": round(opt(m[key], "server_spec_accept_rate", 0.0), 4),
        "baseline_tokens": b[key].get("total_tokens", 0),
        "mtp_tokens": m[key].get("total_tokens", 0),
        "baseline_wall_time": round(opt(b[key], "wall_time", 0.0), 3),
        "mtp_wall_time": round(opt(m[key], "wall_time", 0.0), 3),
    })

if not rows:
    print("warning: no (context, concurrency) cell is present in both runs",
          file=sys.stderr)

summary = {
    "baseline": baseline.get("metadata", {}),
    "mtp": mtp.get("metadata", {}),
    "comparison": rows,
}
with open(out_path, "w") as f:
    json.dump(summary, f, indent=2)

print("\n=== Gemma 4 MTP comparison ===")
print("ctx\tconc\tbaseline\tmtp\tspeedup\taccept")
for r in rows:
    print(
        f"{r['context_tokens']}\t{r['concurrency']}\t"
        f"{r['baseline_tps']}\t{r['mtp_tps']}\t"
        f"{r['speedup']}x\t{r['mtp_accept_rate']}"
    )
print(f"\nSummary saved to {out_path}")
PY
}

run_case "baseline" 0
run_case "mtp_topk1" 1

summarize_results \
  "$RESULTS_DIR/baseline.json" \
  "$RESULTS_DIR/mtp_topk1.json" \
  "$RESULTS_DIR/summary.json"
