Add support for ColSmol #76

Open · wants to merge 7 commits into main
2 changes: 1 addition & 1 deletion README.md
@@ -11,7 +11,7 @@ Byaldi is [RAGatouille](https://github.com/answerdotai/ragatouille)'s mini siste

First, a warning: This is a pre-release library, using uncompressed indexes and lacking other kinds of refinements.

Currently, we support all models supported by the underlying [colpali-engine](https://github.com/illuin-tech/colpali), including the newer, and better, ColQwen2 checkpoints, such as `vidore/colqwen2-v1.0`. Broadly, the aim is for byaldi to support all ColVLM models.
Currently, we support all models supported by the underlying [colpali-engine](https://github.com/illuin-tech/colpali), including the newer, and better, ColQwen2 checkpoints, such as `vidore/colqwen2-v1.0`. You can also use `byaldi` with the smaller ColSmol models (`vidore/colSmol-256M`, `vidore/colSmol-500M`) if you are working under tighter hardware constraints. Broadly, the aim is for `byaldi` to support all ColVLM models.

Additional backends will be supported in future updates. As byaldi exists to facilitate the adoption of multi-modal retrievers, we intend to also add support for models such as [VisRAG](https://github.com/openbmb/visrag).

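For reference, the new ColSmol checkpoints slot into the existing byaldi workflow unchanged; a minimal sketch, assuming the standard `RAGMultiModalModel` index/search API from the README and an illustrative PDF path:

```python
from byaldi import RAGMultiModalModel

# The "colsmol" substring in the checkpoint name routes to the new
# ColIdefics3 code path added in this PR.
model = RAGMultiModalModel.from_pretrained("vidore/colSmol-256M", device="cuda")  # or "cpu"/"mps"

# Index a document and query it with the usual byaldi calls.
model.index(
    input_path="docs/sample.pdf",  # illustrative path
    index_name="sample_colsmol",
    overwrite=True,
)
results = model.search("What does the first figure show?", k=3)
for r in results:
    print(r.doc_id, r.page_num, r.score)
```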
1 change: 0 additions & 1 deletion byaldi/RAGModel.py
@@ -4,7 +4,6 @@
from PIL import Image

from byaldi.colpali import ColPaliModel

from byaldi.objects import Result

# Optional langchain integration
30 changes: 29 additions & 1 deletion byaldi/colpali.py
@@ -7,7 +7,14 @@

import srsly
import torch
from colpali_engine.models import ColPali, ColPaliProcessor, ColQwen2, ColQwen2Processor
from colpali_engine.models import (
ColIdefics3,
ColIdefics3Processor,
ColPali,
ColPaliProcessor,
ColQwen2,
ColQwen2Processor,
)
from pdf2image import convert_from_path
from PIL import Image

@@ -35,6 +42,7 @@ def __init__(
if (
"colpali" not in pretrained_model_name_or_path.lower()
and "colqwen2" not in pretrained_model_name_or_path.lower()
and "colsmol" not in pretrained_model_name_or_path.lower()
):
raise ValueError(
"This pre-release version of Byaldi only supports ColPali and ColQwen2 for now. Incorrect model name specified."
@@ -89,6 +97,18 @@ def __init__(
),
token=kwargs.get("hf_token", None) or os.environ.get("HF_TOKEN"),
)
elif "colsmol" in pretrained_model_name_or_path.lower():
self.model = ColIdefics3.from_pretrained(
self.pretrained_model_name_or_path,
torch_dtype=torch.bfloat16,
device_map=(
"cuda"
if device == "cuda"
or (isinstance(device, torch.device) and device.type == "cuda")
else None
),
token=kwargs.get("hf_token", None) or os.environ.get("HF_TOKEN"),
)
self.model = self.model.eval()

if "colpali" in pretrained_model_name_or_path.lower():
@@ -107,6 +127,14 @@ def __init__(
token=kwargs.get("hf_token", None) or os.environ.get("HF_TOKEN"),
),
)
elif "colsmol" in pretrained_model_name_or_path.lower():
self.processor = cast(
ColIdefics3Processor,
ColIdefics3Processor.from_pretrained(
self.pretrained_model_name_or_path,
token=kwargs.get("hf_token", None) or os.environ.get("HF_TOKEN"),
),
)

self.device = device
if device != "cuda" and not (
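The hunks above add a third branch to byaldi's name-based dispatch: any checkpoint whose name contains `colsmol` is loaded as a `ColIdefics3`/`ColIdefics3Processor` pair. As a standalone illustration of that pattern (the registry and helper below are hypothetical, not byaldi's actual internals):

```python
import torch
from colpali_engine.models import (
    ColIdefics3,
    ColIdefics3Processor,
    ColPali,
    ColPaliProcessor,
    ColQwen2,
    ColQwen2Processor,
)

# Substring in the checkpoint name -> (model class, processor class).
_REGISTRY = {
    "colsmol": (ColIdefics3, ColIdefics3Processor),  # new in this PR
    "colqwen2": (ColQwen2, ColQwen2Processor),
    "colpali": (ColPali, ColPaliProcessor),
}


def load_col_model(name: str, device: str = "cuda"):
    """Load the matching colpali-engine model/processor pair for a checkpoint name."""
    for key, (model_cls, processor_cls) in _REGISTRY.items():
        if key in name.lower():
            model = model_cls.from_pretrained(
                name,
                torch_dtype=torch.bfloat16,
                device_map=device if device == "cuda" else None,
            ).eval()
            processor = processor_cls.from_pretrained(name)
            return model, processor
    raise ValueError(f"Unsupported checkpoint name: {name}")
```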
2 changes: 1 addition & 1 deletion byaldi/integrations/__init__.py
@@ -1,7 +1,7 @@
_all__ = []

try:
from byaldi.integrations._langchain import ByaldiLangChainRetriever
from byaldi.integrations._langchain import ByaldiLangChainRetriever # noqa: F401

_all__.append("ByaldiLangChainRetriever")
except ImportError:
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -28,14 +28,14 @@ maintainers = [
]

dependencies = [
"colpali-engine>=0.3.4,<0.4.0",
"colpali-engine>=0.3.7,<0.4.0",
"ml-dtypes",
"mteb==1.6.35",
"ninja",
"pdf2image",
"srsly",
"torch",
"transformers>=4.42.0",
"transformers>=4.47.0",
]

[project.optional-dependencies]
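`ColIdefics3` only ships in newer colpali-engine releases, which is why the version floors move here. A quick sanity check one could run when hitting an `ImportError` on `ColIdefics3` (assumes the `packaging` package is available; not part of this PR):

```python
from importlib.metadata import version

from packaging.version import Version

# Version floors raised in pyproject.toml by this PR.
assert Version(version("colpali-engine")) >= Version("0.3.7"), "upgrade colpali-engine"
assert Version(version("transformers")) >= Version("4.47.0"), "upgrade transformers"
```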
2 changes: 1 addition & 1 deletion tests/test_colpali.py
@@ -12,7 +12,7 @@
def colpali_rag_model() -> Generator[RAGMultiModalModel, None, None]:
device = get_torch_device("auto")
print(f"Using device: {device}")
yield RAGMultiModalModel.from_pretrained("vidore/colpali-v1.2", device=device)
yield RAGMultiModalModel.from_pretrained("vidore/colpali-v1.3", device=device)
tear_down_torch()


2 changes: 1 addition & 1 deletion tests/test_colqwen.py
@@ -12,7 +12,7 @@
def colqwen_rag_model() -> Generator[RAGMultiModalModel, None, None]:
device = get_torch_device("auto")
print(f"Using device: {device}")
yield RAGMultiModalModel.from_pretrained("vidore/colqwen2-v0.1", device=device)
yield RAGMultiModalModel.from_pretrained("vidore/colqwen2-v1.0", device=device)
tear_down_torch()


23 changes: 23 additions & 0 deletions tests/test_colsmol.py
@@ -0,0 +1,23 @@
from typing import Generator

import pytest
from colpali_engine.models import ColIdefics3
from colpali_engine.utils.torch_utils import get_torch_device, tear_down_torch

from byaldi import RAGMultiModalModel
from byaldi.colpali import ColPaliModel


@pytest.fixture(scope="module")
def colsmol_rag_model() -> Generator[RAGMultiModalModel, None, None]:
device = get_torch_device("auto")
print(f"Using device: {device}")
yield RAGMultiModalModel.from_pretrained("vidore/colSmol-256M", device=device)
tear_down_torch()


@pytest.mark.slow
def test_load_colsmol_from_pretrained(colsmol_rag_model: RAGMultiModalModel):
assert isinstance(colsmol_rag_model, RAGMultiModalModel)
assert isinstance(colsmol_rag_model.model, ColPaliModel)
assert isinstance(colsmol_rag_model.model.model, ColIdefics3)
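The new test mirrors the existing ColPali and ColQwen fixtures and sits behind the `slow` marker, so it only runs when selected explicitly, for example via pytest's standard `-m` filter; a programmatic equivalent (invocation path assumed, not part of this PR):

```python
# Same as running `pytest -m slow tests/test_colsmol.py -s` from the repo root.
import pytest

raise SystemExit(pytest.main(["-m", "slow", "tests/test_colsmol.py", "-s"]))
```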