#!/usr/bin/env bash
# Launch an SGLang server (`sglang serve`) for one Gemma 4 model variant.
#
# All settings are read from environment variables (VARIANT, MODEL_PATH,
# PORT, TP_SIZE, ENABLE_MTP, GPU_DEVICES, ...). DRY_RUN=1 prints the resolved
# sglang command line and exits without creating directories or starting the
# server. An unrecognised VARIANT aborts with exit status 2.
set -euo pipefail

# VARIANT selects the checkpoint family. It is upper-cased before matching,
# so e.g. "moe", "26b" and "26B-A4B" all pick the same model.
VARIANT="${VARIANT:-26B-A4B}"
VARIANT_KEY="$(printf '%s' "$VARIANT" | tr '[:lower:]' '[:upper:]')"

case "$VARIANT_KEY" in
  26B-A4B | 26B | MOE) DEFAULT_MODEL="google/gemma-4-26B-A4B-it" ;;
  31B | DENSE-31B) DEFAULT_MODEL="google/gemma-4-31B-it" ;;
  E4B | 4B) DEFAULT_MODEL="google/gemma-4-E4B-it" ;;
  E2B | 2B) DEFAULT_MODEL="google/gemma-4-E2B-it" ;;
  *)
    echo "Unknown VARIANT=$VARIANT. Use 26B-A4B, 31B, E4B, or E2B." >&2
    exit 2
    ;;
esac

# For every variant the draft-model id is the model id plus "-assistant",
# and the name served to clients is the model id itself.
DEFAULT_DRAFT_MODEL="${DEFAULT_MODEL}-assistant"
DEFAULT_SERVED_MODEL_NAME="$DEFAULT_MODEL"

# ---------------------------------------------------------------------------
# Runtime configuration. Each setting may be overridden from the environment;
# an unset *or empty* variable falls back to the default written here.
# ---------------------------------------------------------------------------

# Checkpoints and the model name exposed to clients (defaults follow VARIANT).
: "${MODEL_PATH:=$DEFAULT_MODEL}"
: "${DRAFT_MODEL_PATH:=$DEFAULT_DRAFT_MODEL}"
: "${SERVED_MODEL_NAME:=$DEFAULT_SERVED_MODEL_NAME}"

# Listen address.
: "${HOST:=0.0.0.0}"
: "${PORT:=30000}"

# CUDA toolkit root: use /usr/local/cuda-13.0 when its nvcc is present,
# otherwise the generic /usr/local/cuda.
if [[ -z "${CUDA_HOME:-}" ]]; then
  CUDA_HOME=/usr/local/cuda
  if [[ -x /usr/local/cuda-13.0/bin/nvcc ]]; then
    CUDA_HOME=/usr/local/cuda-13.0
  fi
fi

# Parallelism, memory and scheduling limits.
: "${TP_SIZE:=1}"
: "${MEM_FRACTION_STATIC:=0.85}"
: "${CONTEXT_LENGTH:=262144}"
: "${MAX_RUNNING_REQUESTS:=48}"

# Speculative decoding parameters (only passed on when ENABLE_MTP=1).
: "${ENABLE_MTP:=0}"
: "${SPECULATIVE_NUM_STEPS:=5}"
: "${SPECULATIVE_NUM_DRAFT_TOKENS:=6}"
: "${SPECULATIVE_EAGLE_TOPK:=1}"

# Optional overrides. They default to empty, which means the corresponding
# sglang flag is not passed (or CUDA_VISIBLE_DEVICES is left untouched).
: "${ATTENTION_BACKEND:=}"
: "${MOE_RUNNER_BACKEND:=}"
: "${KV_CACHE_DTYPE:=}"
: "${CUDA_GRAPH_MAX_BS:=}"
: "${GPU_DEVICES:=}"

# Download/cache locations, rooted under /workspace by default.
: "${HF_HOME:=/workspace/.cache/huggingface}"
: "${HUGGINGFACE_HUB_CACHE:=$HF_HOME/hub}"
: "${TRANSFORMERS_CACHE:=$HF_HOME/transformers}"
: "${PIP_CACHE_DIR:=/workspace/.cache/pip}"

# Hand the resolved locations to child processes.
export HF_HOME HUGGINGFACE_HUB_CACHE TRANSFORMERS_CACHE PIP_CACHE_DIR CUDA_HOME

# Library switches (Hugging Face hub, safetensors, OpenMP, FlashInfer,
# SGLang DeepGEMM). A value already present in the environment wins.
: "${HF_HUB_DISABLE_XET:=1}"
: "${SAFETENSORS_FAST_GPU:=1}"
: "${OMP_NUM_THREADS:=8}"
: "${FLASHINFER_DISABLE_VERSION_CHECK:=1}"
: "${SGLANG_DISABLE_DEEP_GEMM:=1}"
: "${SGLANG_ENABLE_DEEP_GEMM:=0}"
: "${SGLANG_JIT_DEEPGEMM_PRECOMPILE:=0}"
export HF_HUB_DISABLE_XET SAFETENSORS_FAST_GPU OMP_NUM_THREADS \
  FLASHINFER_DISABLE_VERSION_CHECK SGLANG_DISABLE_DEEP_GEMM \
  SGLANG_ENABLE_DEEP_GEMM SGLANG_JIT_DEEPGEMM_PRECOMPILE

#######################################
# Put a CUDA toolkit's bin/ and lib64/ directories in front of PATH and
# LD_LIBRARY_PATH. Each directory is added only if it exists.
# The previous value is appended only when it is non-empty. A trailing ':'
# would create an empty entry, which the dynamic loader treats as the
# current working directory.
# Arguments: $1 - CUDA toolkit root (normally $CUDA_HOME)
# Globals:   PATH, LD_LIBRARY_PATH (written and exported)
# Returns:   0
#######################################
_add_cuda_toolkit_paths() {
  local cuda_root=$1
  if [[ -d "$cuda_root/bin" ]]; then
    export PATH="$cuda_root/bin${PATH:+:$PATH}"
  fi
  if [[ -d "$cuda_root/lib64" ]]; then
    export LD_LIBRARY_PATH="$cuda_root/lib64${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}"
  fi
}
_add_cuda_toolkit_paths "$CUDA_HOME"

# Pre-create all cache directories. A dry run (DRY_RUN=1) skips this so that
# it leaves the filesystem untouched.
[[ "${DRY_RUN:-0}" == "1" ]] ||
  mkdir -p "$HF_HOME" "$HUGGINGFACE_HUB_CACHE" "$TRANSFORMERS_CACHE" "$PIP_CACHE_DIR"

#######################################
# Expose the torch and NVIDIA cu13 libraries/headers installed in a
# virtualenv's site-packages to the dynamic loader (LD_LIBRARY_PATH) and to
# the C/C++ compiler (LIBRARY_PATH, CPATH, CPLUS_INCLUDE_PATH).
#
# Only directories that exist are added. A previous value is appended only
# when non-empty: an empty path element means "current directory" to the
# loader and to GCC's include-path variables.
#
# Arguments: $1 - virtualenv root (normally $VIRTUAL_ENV)
# Globals:   LD_LIBRARY_PATH, LIBRARY_PATH, CPATH, CPLUS_INCLUDE_PATH
#            (written and exported)
# Returns:   0. If the interpreter's version cannot be determined, it warns
#            on stderr and changes nothing.
#######################################
_add_venv_cuda_paths() {
  local venv_root=$1
  local py="$venv_root/bin/python"
  local py_version site_packages torch_lib_dir nvidia_lib_dir nvidia_include_dir
  local loader_dirs=""

  # Prefer the venv's own interpreter. A different python earlier on PATH
  # would report the wrong lib/pythonX.Y directory.
  if [[ ! -x "$py" ]]; then
    py=python
  fi
  if ! py_version="$("$py" -c 'import sys; print(f"{sys.version_info.major}.{sys.version_info.minor}")')"; then
    echo "warning: could not determine the Python version for $venv_root; not adding its library paths" >&2
    return 0
  fi

  site_packages="$venv_root/lib/python${py_version}/site-packages"
  torch_lib_dir="$site_packages/torch/lib"
  nvidia_lib_dir="$site_packages/nvidia/cu13/lib"
  nvidia_include_dir="$site_packages/nvidia/cu13/include"

  # Loader order: torch/lib first, then nvidia/cu13/lib (whichever exist).
  if [[ -d "$torch_lib_dir" ]]; then
    loader_dirs="$torch_lib_dir"
  fi
  if [[ -d "$nvidia_lib_dir" ]]; then
    loader_dirs="${loader_dirs:+$loader_dirs:}$nvidia_lib_dir"
    export LIBRARY_PATH="$nvidia_lib_dir${LIBRARY_PATH:+:$LIBRARY_PATH}"
  fi
  if [[ -n "$loader_dirs" ]]; then
    export LD_LIBRARY_PATH="$loader_dirs${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}"
  fi
  if [[ -d "$nvidia_include_dir" ]]; then
    export CPATH="$nvidia_include_dir${CPATH:+:$CPATH}"
    export CPLUS_INCLUDE_PATH="$nvidia_include_dir${CPLUS_INCLUDE_PATH:+:$CPLUS_INCLUDE_PATH}"
  fi
}

if [[ -n "${VIRTUAL_ENV:-}" ]]; then
  _add_venv_cuda_paths "$VIRTUAL_ENV"
fi

# A non-empty GPU_DEVICES becomes CUDA_VISIBLE_DEVICES for the server process;
# otherwise any inherited CUDA_VISIBLE_DEVICES is left as it is.
[[ -z "$GPU_DEVICES" ]] || export CUDA_VISIBLE_DEVICES="$GPU_DEVICES"

# Build the `sglang serve` argument vector as an array, so values containing
# spaces stay single arguments.

# Append "FLAG VALUE" to args only when VALUE is non-empty, so an override
# that was not given simply omits the flag.
_append_optional_flag() {
  if [[ -n "$2" ]]; then
    args+=("$1" "$2")
  fi
}

args=(serve)
# Checkpoint to load and the name clients use for it.
args+=(--model-path "$MODEL_PATH" --served-model-name "$SERVED_MODEL_NAME")
# Endpoint and tensor-parallel size.
args+=(--host "$HOST" --port "$PORT" --tp "$TP_SIZE")
args+=(--trust-remote-code)
# Gemma 4 reasoning / tool-call output parsing.
args+=(--reasoning-parser gemma4 --tool-call-parser gemma4)
# Memory and scheduling limits.
args+=(
  --mem-fraction-static "$MEM_FRACTION_STATIC"
  --context-length "$CONTEXT_LENGTH"
  --max-running-requests "$MAX_RUNNING_REQUESTS"
)
args+=(--enable-metrics)

_append_optional_flag --attention-backend "$ATTENTION_BACKEND"
_append_optional_flag --moe-runner-backend "$MOE_RUNNER_BACKEND"
_append_optional_flag --kv-cache-dtype "$KV_CACHE_DTYPE"
_append_optional_flag --cuda-graph-max-bs "$CUDA_GRAPH_MAX_BS"

# ENABLE_MTP=1 turns on NEXTN speculative decoding with DRAFT_MODEL_PATH as
# the draft model.
if [[ "$ENABLE_MTP" == "1" ]]; then
  args+=(
    --speculative-algorithm NEXTN
    --speculative-draft-model-path "$DRAFT_MODEL_PATH"
    --speculative-num-steps "$SPECULATIVE_NUM_STEPS"
    --speculative-num-draft-tokens "$SPECULATIVE_NUM_DRAFT_TOKENS"
    --speculative-eagle-topk "$SPECULATIVE_EAGLE_TOPK"
  )
fi

# Report the effective configuration and the exact command line. The command
# is shell-quoted with %q so it can be copy-pasted. DRY_RUN=1 stops here;
# otherwise this shell is replaced by the sglang process.
printf 'Serving %s as %s on %s:%s with SGLang\n' \
  "$MODEL_PATH" "$SERVED_MODEL_NAME" "$HOST" "$PORT"
printf 'variant=%s tp=%s cuda_visible_devices=%s cuda_home=%s context_length=%s mem_fraction_static=%s max_running_requests=%s mtp=%s\n' \
  "$VARIANT" "$TP_SIZE" "${CUDA_VISIBLE_DEVICES:-all}" "$CUDA_HOME" \
  "$CONTEXT_LENGTH" "$MEM_FRACTION_STATIC" "$MAX_RUNNING_REQUESTS" "$ENABLE_MTP"
if [[ "$ENABLE_MTP" == "1" ]]; then
  printf 'draft=%s steps=%s draft_tokens=%s topk=%s\n' \
    "$DRAFT_MODEL_PATH" "$SPECULATIVE_NUM_STEPS" \
    "$SPECULATIVE_NUM_DRAFT_TOKENS" "$SPECULATIVE_EAGLE_TOPK"
fi
printf 'Command: sglang%s\n' "$(printf ' %q' "${args[@]}")"

if [[ "${DRY_RUN:-0}" == "1" ]]; then
  exit 0
fi

exec sglang "${args[@]}"
