Qwen3-VL mode working; /unload; normal model loading times
@@ -18,6 +18,7 @@ from transformers import (
     AutoTokenizer,
     AutoProcessor,
     AutoModel,
+    AutoModelForVision2Seq,
 )
 from transformers.utils.import_utils import is_flash_attn_2_available
 
@@ -177,6 +178,10 @@ class EmbeddingResponse(BaseModel):
     usage: Usage
 
 
+class PreloadRequest(BaseModel):
+    model: str
+
+
 # -----------------------------------------------------------------------------
 # Helpers
 # -----------------------------------------------------------------------------
@@ -248,28 +253,29 @@ def _load_model_locked(model_id: str):
         # Load Generation Model
         # Check if it is a VL model
         if "VL" in model_id:
-            # Attempt to load as VL
-            # Using AutoModelForVision2Seq or AutoModelForCausalLM
-            # depending on the specific model support in transformers
-            try:
-                from transformers import Qwen2VLForConditionalGeneration
-                model_class = Qwen2VLForConditionalGeneration
-            except ImportError:
-                # Fallback to AutoModel if specific class not available
-                model_class = AutoModelForCausalLM
-
-            # Note: We use AutoModelForCausalLM for broad compatibility.
-            # Qwen2-VL requires Qwen2VLForConditionalGeneration for vision.
-            # We will try AutoModelForCausalLM first.
-
-            model = AutoModelForCausalLM.from_pretrained(
-                model_id,
-                torch_dtype=dtype,
-                device_map=device_map,
-                attn_implementation=attn_impl,
-                trust_remote_code=True, # Often needed for new architectures
-            ).eval()
-
+            # Use AutoModelForVision2Seq for VL models
+            # The configuration class Qwen3VLConfig requires Vision2Seq or AutoModel
+            try:
+                print(f"Loading {model_id} with AutoModelForVision2Seq...")
+                model = AutoModelForVision2Seq.from_pretrained(
+                    model_id,
+                    torch_dtype=dtype,
+                    device_map=device_map,
+                    attn_implementation=attn_impl,
+                    trust_remote_code=True,
+                    low_cpu_mem_usage=True,
+                ).eval()
+            except Exception as e:
+                print(f"Vision2Seq failed: {e}. Fallback to AutoModel...")
+                # Fallback to generic AutoModel if Vision2Seq fails
+                model = AutoModel.from_pretrained(
+                    model_id,
+                    torch_dtype=dtype,
+                    device_map=device_map,
+                    attn_implementation=attn_impl,
+                    trust_remote_code=True,
+                    low_cpu_mem_usage=True,
+                ).eval()
 
         # Processor/Tokenizer
         try:
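For context, the snippet below is a minimal, self-contained sketch (not part of this commit) of how a model loaded with AutoModelForVision2Seq and its AutoProcessor is typically driven for one image + text prompt. The model id, image URL, and prompt are placeholder assumptions, and the server's actual generate path may differ.

# Sketch only: placeholder model id, image, and prompt.
import requests
import torch
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor

model_id = "Qwen/Qwen3-VL-8B-Instruct"  # placeholder, not necessarily an allowed id
model = AutoModelForVision2Seq.from_pretrained(
    model_id, torch_dtype="auto", device_map="auto", trust_remote_code=True
).eval()
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

image = Image.open(requests.get("https://example.com/cat.png", stream=True).raw)
messages = [
    {"role": "user", "content": [
        {"type": "image"},
        {"type": "text", "text": "Describe this image."},
    ]}
]
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(text=[prompt], images=[image], return_tensors="pt").to(model.device)

with torch.inference_mode():
    out = model.generate(**inputs, max_new_tokens=128)
# Strip the prompt tokens and decode only the newly generated text.
print(processor.batch_decode(out[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True)[0])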
@@ -284,12 +290,14 @@ def _load_model_locked(model_id: str):
         _loaded_model_type = "generation"
     else:
         # Standard Text Model (GPT-OSS)
+        print(f"Loading {model_id} with AutoModelForCausalLM...")
         model = AutoModelForCausalLM.from_pretrained(
             model_id,
             torch_dtype=dtype,
             device_map=device_map,
             attn_implementation=attn_impl,
             trust_remote_code=True,
+            low_cpu_mem_usage=True,
         ).eval()
         processor = AutoTokenizer.from_pretrained(
             model_id, trust_remote_code=True
@@ -367,6 +375,41 @@ def _extract_embeddings(outputs) -> torch.Tensor:
 # -----------------------------------------------------------------------------
 
 
+@app.post("/preload")
+def preload_model(request: PreloadRequest):
+    model_id = request.model.strip()
+    if model_id not in ALLOWED_MODEL_IDS:
+        raise HTTPException(
+            status_code=400,
+            detail=f"Model {model_id} not in allowed models.",
+        )
+
+    with _model_lock:
+        try:
+            _ensure_model_loaded(model_id)
+        except Exception as e:
+            raise HTTPException(
+                status_code=500, detail=f"Failed to load model: {e}"
+            )
+
+    return {
+        "status": "ok",
+        "loaded_model_id": _loaded_model_id,
+        "vram_bytes": _current_vram_info(),
+    }
+
+
+@app.post("/unload")
+def unload_model():
+    with _model_lock:
+        stats = _unload_model_locked()
+    return {
+        "status": "ok",
+        "vram_bytes": _current_vram_info(),
+        "stats": stats,
+    }
+
+
 @app.get("/health")
 def health():
     cuda_ok = bool(torch.cuda.is_available())
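Example client usage for the new endpoints, as a sketch assuming the server listens on http://localhost:8000 and that the placeholder model id is present in ALLOWED_MODEL_IDS:

# Sketch only: host, port, and model id are assumptions.
import requests

BASE = "http://localhost:8000"

# Warm the model before the first real request.
r = requests.post(f"{BASE}/preload", json={"model": "Qwen/Qwen3-VL-8B-Instruct"})
print(r.status_code, r.json())  # {"status": "ok", "loaded_model_id": ..., "vram_bytes": ...}

# Free GPU memory when done.
r = requests.post(f"{BASE}/unload")
print(r.status_code, r.json())  # {"status": "ok", "vram_bytes": ..., "stats": ...}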
|