#!/usr/bin/env bash
#
# Create a Python virtualenv with a pinned SGLang + transformers stack for the
# Gemma 4 MTP experiment, using PyTorch wheels from a CUDA 13 (cu130) index.
#
# Overridable environment variables (defaults in parentheses):
#   VENV_DIR            venv location (/root/venvs/sglang-gemma4-mtp)
#   PYTHON_BIN          interpreter used to create the venv (python3)
#   CUDA_HOME           CUDA toolkit root (/usr/local/cuda-13.0 when it has an
#                       executable nvcc, otherwise /usr/local/cuda)
#   PIP_CACHE_DIR       pip cache directory (/workspace/.cache/pip)
#   SGLANG_REF          SGLang commit to install
#   SGLANG_SPEC         full pip requirement for SGLang (built from SGLANG_REF)
#   TRANSFORMERS_REF    transformers commit to install
#   TRANSFORMERS_SPEC   full pip requirement for transformers
#   TORCH_INDEX_URL     PyTorch wheel index (https://download.pytorch.org/whl/cu130)
#   HF_HUB_DISABLE_XET  exported; defaults to 1
set -euo pipefail

VENV_DIR="${VENV_DIR:-/root/venvs/sglang-gemma4-mtp}"
PYTHON_BIN="${PYTHON_BIN:-python3}"
if [[ -z "${CUDA_HOME:-}" ]]; then
  # Prefer an explicit CUDA 13.0 install when one is present.
  if [[ -x /usr/local/cuda-13.0/bin/nvcc ]]; then
    CUDA_HOME="/usr/local/cuda-13.0"
  else
    CUDA_HOME="/usr/local/cuda"
  fi
fi
PIP_CACHE_DIR="${PIP_CACHE_DIR:-/workspace/.cache/pip}"

# PR #24436 merged on 2026-05-07, but pinning the PR's head commit keeps this
# experiment reproducible if SGLang main moves underneath us.
SGLANG_REF="${SGLANG_REF:-bcf8d100d06a70009f5b591652464f1e7ff86116}"
SGLANG_SPEC="${SGLANG_SPEC:-git+https://github.com/sgl-project/sglang.git@${SGLANG_REF}#subdirectory=python}"
TRANSFORMERS_REF="${TRANSFORMERS_REF:-2c7d385621c80fee70c1472f3a622fcba2c93fb9}"
TRANSFORMERS_SPEC="${TRANSFORMERS_SPEC:-git+https://github.com/huggingface/transformers.git@${TRANSFORMERS_REF}}"
TORCH_INDEX_URL="${TORCH_INDEX_URL:-https://download.pytorch.org/whl/cu130}"

export CUDA_HOME
export PIP_CACHE_DIR
export HF_HUB_DISABLE_XET="${HF_HUB_DISABLE_XET:-1}"

if [[ -d "$CUDA_HOME/bin" ]]; then
  export PATH="$CUDA_HOME/bin:$PATH"
fi
if [[ -d "$CUDA_HOME/lib64" ]]; then
  # Add the ":" separator only when LD_LIBRARY_PATH already has a value: an
  # empty component (e.g. a trailing ":") makes the dynamic linker search the
  # current working directory.
  export LD_LIBRARY_PATH="$CUDA_HOME/lib64${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}"
fi

# Put rustup's toolchain ahead of Ubuntu's packaged cargo/rustc, which is too
# old: SGLang's current dependency graph contains Rust 2024 crates.
cargo_env_script="$HOME/.cargo/env"
if [[ -f "$cargo_env_script" ]]; then
  # shellcheck disable=SC1090
  source "$cargo_env_script"
fi

# Prepend unconditionally so ~/.cargo/bin wins even when the env script is
# absent (or already placed it elsewhere on PATH).
cargo_bin_dir="$HOME/.cargo/bin"
export PATH="$cargo_bin_dir:$PATH"

# Human-readable names of prerequisites that are not available; reported
# together once all checks have run.
missing=()

#######################################
# Record a missing prerequisite if a command is not on PATH.
# Globals:   missing (appended to)
# Arguments: $1 - command name to look up with `command -v`
#            $2 - package/hint to report (optional; defaults to $1)
# Returns:   0 always
#######################################
need_cmd() {
  local cmd=$1
  local pkg=${2:-$1}
  if ! command -v "$cmd" >/dev/null 2>&1; then
    missing+=("$pkg")
  fi
}

need_cmd git git
need_cmd curl curl
need_cmd protoc protobuf-compiler

# `python -m venv --help` can succeed on Debian/Ubuntu even when the
# pythonX.Y-venv package is absent; creating a venv with pip then fails because
# ensurepip is unavailable. Probe ensurepip as well so the hint below is shown
# up front instead of a failure at venv creation time.
if ! "$PYTHON_BIN" -m venv --help >/dev/null 2>&1 \
  || ! "$PYTHON_BIN" -c 'import ensurepip' >/dev/null 2>&1; then
  missing+=("python3.10-venv")
fi

if ! command -v rustc >/dev/null 2>&1; then
  missing+=("rustup-managed rustc")
elif ! rustc_version_line="$(rustc --version 2>/dev/null)"; then
  # rustc is on PATH but cannot run (e.g. a rustup shim with no default
  # toolchain). Handling it here keeps `set -e` from aborting the script
  # before the install hint is printed.
  missing+=("rustup-managed rustc")
else
  # Output looks like "rustc 1.85.0 (<hash> <date>)"; the 2nd field is the version.
  read -r _ rust_version _ <<<"$rustc_version_line"
  IFS=. read -r rust_major rust_minor _ <<<"$rust_version"
  # Treat an unparseable version as too old rather than letting the
  # arithmetic test error out.
  if [[ ! "$rust_major" =~ ^[0-9]+$ || ! "$rust_minor" =~ ^[0-9]+$ ]] \
    || (( rust_major < 1 || (rust_major == 1 && rust_minor < 85) )); then
    missing+=("rustup-managed rustc>=1.85")
  fi
fi

if (( ${#missing[@]} > 0 )); then
  printf 'Missing prerequisites: %s\n' "${missing[*]}" >&2
  cat >&2 <<'EOF'
On a fresh Ubuntu Blackwell box, install the system pieces with:

  apt-get update
  apt-get install -y python3.10-venv git curl ca-certificates build-essential pkg-config protobuf-compiler
  curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
  source ~/.cargo/env

Then rerun this installer.
EOF
  exit 2
fi

# Make sure the venv's parent directory and the pip cache exist, create the
# venv, and switch this script into it.
mkdir -p "$(dirname "$VENV_DIR")" "$PIP_CACHE_DIR"
"$PYTHON_BIN" -m venv "$VENV_DIR"
# shellcheck disable=SC1090
source "$VENV_DIR/bin/activate"

# pip_install ARGS...  — `pip install --upgrade` through the venv's python.
pip_install() {
  python -m pip install --upgrade "$@"
}

pip_install pip setuptools wheel packaging ninja
pip_install torch torchvision torchaudio --index-url "$TORCH_INDEX_URL"

# The pinned transformers is installed both before and after SGLang.
# NOTE(review): presumably so SGLang's own transformers requirement cannot
# leave a different version behind — confirm this is the intent.
for requirement in "$TRANSFORMERS_SPEC" "$SGLANG_SPEC" "$TRANSFORMERS_SPEC"; do
  pip_install "$requirement"
done

# Locate the venv's site-packages so the nvidia/cu13 directories installed
# there (when present) can be put on the library/include search paths.
site_packages="$(python - <<'PY'
import site
print(site.getsitepackages()[0])
PY
)"

#######################################
# Prepend a directory to a colon-separated path variable and export it.
# No separator is added when the variable is unset or empty: a trailing ":"
# creates an empty component, which the dynamic linker and gcc both treat as
# the current directory.
# Arguments: $1 - variable name (e.g. LD_LIBRARY_PATH)
#            $2 - directory to prepend
#######################################
prepend_path_var() {
  local name=$1 dir=$2
  local current=${!name:-}
  printf -v "$name" '%s' "${dir}${current:+:$current}"
  export "${name?}"
}

# These exports only affect the rest of this installer process; sourcing the
# venv's activate script later does not set them.
nvidia_cu13_lib_dir="$site_packages/nvidia/cu13/lib"
nvidia_cu13_include_dir="$site_packages/nvidia/cu13/include"
if [[ -d "$nvidia_cu13_lib_dir" ]]; then
  prepend_path_var LD_LIBRARY_PATH "$nvidia_cu13_lib_dir"
  prepend_path_var LIBRARY_PATH "$nvidia_cu13_lib_dir"
fi
if [[ -d "$nvidia_cu13_include_dir" ]]; then
  prepend_path_var CPATH "$nvidia_cu13_include_dir"
  prepend_path_var CPLUS_INCLUDE_PATH "$nvidia_cu13_include_dir"
fi

# `pip show` exits non-zero when sglang-kernel is not installed. Under
# `set -o pipefail` that would abort the script here, so map it to an empty
# version and let the guard below skip the wheel install.
kernel_version="$(
  python -m pip show sglang-kernel 2>/dev/null | awk '/^Version: / {print $2}'
)" || kernel_version=""
arch="$(uname -m)"

# On x86_64 with the cu130 torch index, install the +cu130 build of the
# already-installed sglang-kernel version from the sgl-project/whl releases.
if [[ -n "$kernel_version" && "$arch" == "x86_64" && "$TORCH_INDEX_URL" == *"cu130"* ]]; then
  python -m pip install --upgrade \
    "https://github.com/sgl-project/whl/releases/download/v${kernel_version}/sglang_kernel-${kernel_version}+cu130-cp310-abi3-manylinux2014_x86_64.whl"
fi

# Smoke-test the new environment: import the key packages and report their
# versions and whether torch can see a CUDA device.
python - <<'PY'
import torch
import transformers
import sglang

cuda_ok = torch.cuda.is_available()
print("torch", torch.__version__)
print("cuda available", cuda_ok)
if cuda_ok:
    print("gpu", torch.cuda.get_device_name(0))
print("transformers", transformers.__version__)
print("sglang", getattr(sglang, "__version__", "unknown"))
PY

printf '%s\n' \
  "Gemma 4 MTP SGLang environment ready at $VENV_DIR" \
  "Activate with: source $VENV_DIR/bin/activate"
