|
| 1 | +import json |
| 2 | +import logging |
| 3 | +import os |
| 4 | +import time |
| 5 | +from collections.abc import Iterator |
| 6 | +from pathlib import Path |
| 7 | +from typing import Literal |
| 8 | + |
| 9 | +import pytest |
| 10 | +import urllib3 |
| 11 | +import vectorize_client as v |
| 12 | +from vectorize_client import ApiClient, ApiException, RetrieveDocumentsRequest |
| 13 | + |
| 14 | + |
@pytest.fixture(scope="session")
def api_token() -> str:
    """Return the Vectorize API token taken from ``VECTORIZE_TOKEN``.

    Raises:
        ValueError: if ``VECTORIZE_TOKEN`` is unset or empty.
    """
    # No fallback default here: the previous placeholder default ("wdcwd")
    # made the missing-token check below unreachable dead code.
    token = os.getenv("VECTORIZE_TOKEN")
    if not token:
        msg = "Please set the VECTORIZE_TOKEN environment variable"
        raise ValueError(msg)
    return token
| 22 | + |
| 23 | + |
@pytest.fixture(scope="session")
def org_id() -> str:
    """Return the Vectorize organization id taken from ``VECTORIZE_ORG``.

    Raises:
        ValueError: if ``VECTORIZE_ORG`` is unset or empty.
    """
    # No fallback default here: the previous placeholder default ("wdcd")
    # made the missing-org check below unreachable dead code.
    org = os.getenv("VECTORIZE_ORG")
    if not org:
        msg = "Please set the VECTORIZE_ORG environment variable"
        raise ValueError(msg)
    return org
| 31 | + |
| 32 | + |
@pytest.fixture(scope="session")
def environment() -> Literal["prod", "dev", "local", "staging"]:
    """Resolve the target environment from ``VECTORIZE_ENV`` (default: prod)."""
    selected = os.getenv("VECTORIZE_ENV", "prod")
    valid_environments = ("prod", "dev", "local", "staging")
    if selected not in valid_environments:
        msg = "Invalid VECTORIZE_ENV environment variable."
        raise ValueError(msg)
    return selected
| 40 | + |
| 41 | + |
@pytest.fixture(scope="session")
def api_client(api_token: str, environment: str) -> Iterator[ApiClient]:
    """Yield a configured Vectorize ``ApiClient`` for the selected environment.

    The client is opened as a context manager so the underlying HTTP
    resources are released when the test session ends.
    """
    hosts = {
        "prod": "https://api.vectorize.io/v1",
        "dev": "https://api-dev.vectorize.io/v1",
        "local": "http://localhost:3000/api",
    }
    # Unknown/other values fall back to staging, matching the original
    # if/elif/else chain (the `environment` fixture already validates input).
    host = hosts.get(environment, "https://api-staging.vectorize.io/v1")

    # The local dev server authenticates through a lambda header instead of
    # the standard bearer token.
    header_name = None
    header_value = None
    if environment == "local":
        header_name = "x-lambda-api-key"
        header_value = api_token

    # NOTE(review): removed leftover debug=True — it dumps every raw HTTP
    # request/response into the test logs, including the Authorization header.
    with v.ApiClient(
        v.Configuration(host=host, access_token=api_token),
        header_name,
        header_value,
    ) as api:
        yield api
| 63 | + |
| 64 | + |
@pytest.fixture(scope="session")
def pipeline_id(api_client: v.ApiClient, org_id: str) -> Iterator[str]:
    """Create a fully-working test pipeline and yield its id.

    Steps: create a FILE_UPLOAD source connector, upload ``research.pdf``
    to it, wire the connector to the built-in VECTORIZE AI platform and
    vector database, create a manually-scheduled pipeline, then poll until
    the uploaded document is retrievable. The pipeline is deleted
    (best-effort) at session teardown.

    Raises:
        ValueError: if the PUT upload of the PDF does not return HTTP 200.
        RuntimeError: if documents are not retrievable within 180 seconds.
    """
    pipelines = v.PipelinesApi(api_client)

    # 1. Create a file-upload source connector to receive the test document.
    connectors_api = v.SourceConnectorsApi(api_client)
    response = connectors_api.create_source_connector(
        org_id,
        v.CreateSourceConnectorRequest(
            v.FileUpload(name="from api", type="FILE_UPLOAD")
        ),
    )
    source_connector_id = response.connector.id
    logging.info("Created source connector %s", source_connector_id)

    # 2. Ask the API for a pre-signed upload URL for the PDF.
    uploads_api = v.UploadsApi(api_client)
    upload_response = uploads_api.start_file_upload_to_connector(
        org_id,
        source_connector_id,
        v.StartFileUploadToConnectorRequest(
            name="research.pdf",
            content_type="application/pdf",
            metadata=json.dumps({"created-from-api": True}),
        ),
    )

    # 3. PUT the local research.pdf (next to this test file) to that URL.
    http = urllib3.PoolManager()
    this_dir = Path(__file__).parent
    file_path = this_dir / "research.pdf"

    with file_path.open("rb") as f:
        http_response = http.request(
            "PUT",
            upload_response.upload_url,
            body=f,
            headers={
                "Content-Type": "application/pdf",
                # Content-Length must match the file on disk for the
                # pre-signed upload to be accepted.
                "Content-Length": str(file_path.stat().st_size),
            },
        )
    if http_response.status != 200:
        msg = "Upload failed:"
        raise ValueError(msg)
    else:
        logging.info("Upload successful")

    # 4. Pick the built-in ("VECTORIZE") AI platform connector.
    ai_platforms = v.AIPlatformConnectorsApi(api_client).get_ai_platform_connectors(
        org_id
    )
    builtin_ai_platform = next(
        c.id for c in ai_platforms.ai_platform_connectors if c.type == "VECTORIZE"
    )
    logging.info("Using AI platform %s", builtin_ai_platform)

    # 5. Pick the built-in ("VECTORIZE") destination vector database.
    vector_databases = v.DestinationConnectorsApi(
        api_client
    ).get_destination_connectors(org_id)
    builtin_vector_db = next(
        c.id for c in vector_databases.destination_connectors if c.type == "VECTORIZE"
    )
    logging.info("Using destination connector %s", builtin_vector_db)

    # 6. Create the pipeline wiring source -> AI platform -> destination.
    #    Manual schedule so nothing runs except what the test triggers.
    pipeline_response = pipelines.create_pipeline(
        org_id,
        v.PipelineConfigurationSchema(
            source_connectors=[
                v.PipelineSourceConnectorSchema(
                    id=source_connector_id,
                    type=v.SourceConnectorType.FILE_UPLOAD,
                    config={},
                )
            ],
            destination_connector=v.PipelineDestinationConnectorSchema(
                id=builtin_vector_db,
                type="VECTORIZE",
                config={},
            ),
            ai_platform_connector=v.PipelineAIPlatformConnectorSchema(
                id=builtin_ai_platform,
                type="VECTORIZE",
                config={},
            ),
            pipeline_name="Test pipeline",
            schedule=v.ScheduleSchema(type="manual"),
        ),
    )
    pipeline_id = pipeline_response.data.id

    # Wait for the pipeline to be created
    # Poll retrieval until the uploaded document is indexed: 503 means the
    # pipeline is not ready yet and is retried; any other ApiException is a
    # real failure and is re-raised. Give up after 180 seconds.
    request = RetrieveDocumentsRequest(
        question="query",
        num_results=2,
    )
    start = time.time()
    while True:
        try:
            response = pipelines.retrieve_documents(org_id, pipeline_id, request)
        except ApiException as e:
            if "503" not in str(e):
                raise
        else:
            docs = response.documents
            # num_results=2 was requested, so 2 documents back means the
            # index is populated and the pipeline is usable.
            if len(docs) == 2:
                break
        if time.time() - start > 180:
            msg = "Docs not retrieved in time"
            raise RuntimeError(msg)
        time.sleep(1)

    logging.info("Created pipeline %s", pipeline_id)

    yield pipeline_id

    # Teardown: best-effort delete; log (don't fail the session) on error.
    try:
        pipelines.delete_pipeline(org_id, pipeline_id)
    except Exception:
        logging.exception("Failed to delete pipeline %s", pipeline_id)
0 commit comments