diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4814fbe..54c90de 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -38,4 +38,4 @@ repos: hooks: - id: mypy files: ^arangoasync/ - additional_dependencies: ['types-requests', "types-setuptools"] + additional_dependencies: ["types-requests", "types-setuptools"] diff --git a/arangoasync/__init__.py b/arangoasync/__init__.py index 58f3ace..5cc9189 100644 --- a/arangoasync/__init__.py +++ b/arangoasync/__init__.py @@ -1 +1,5 @@ +import logging + from .version import __version__ + +logger = logging.getLogger(__name__) diff --git a/arangoasync/auth.py b/arangoasync/auth.py new file mode 100644 index 0000000..ff703cf --- /dev/null +++ b/arangoasync/auth.py @@ -0,0 +1,74 @@ +__all__ = [ + "Auth", + "JwtToken", +] + +from dataclasses import dataclass + +import jwt + + +@dataclass +class Auth: + """Authentication details for the ArangoDB instance. + + Attributes: + username (str): Username. + password (str): Password. + encoding (str): Encoding for the password (default: utf-8) + """ + + username: str + password: str + encoding: str = "utf-8" + + +class JwtToken: + """JWT token. + + Args: + token (str | bytes): JWT token. + + Raises: + TypeError: If the token type is not str or bytes. + JWTExpiredError: If the token expired. + """ + + def __init__(self, token: str | bytes) -> None: + self._token = token + self._validate() + + @property + def token(self) -> str | bytes: + """Get token.""" + return self._token + + @token.setter + def token(self, token: str | bytes) -> None: + """Set token. + + Raises: + jwt.ExpiredSignatureError: If the token expired. + """ + self._token = token + self._validate() + + def _validate(self) -> None: + """Validate the token.""" + if type(self._token) not in (str, bytes): + raise TypeError("Token must be str or bytes") + + jwt_payload = jwt.decode( + self._token, + issuer="arangodb", + algorithms=["HS256"], + options={ + "require_exp": True, + "require_iat": True, + "verify_iat": True, + "verify_exp": True, + "verify_signature": False, + }, + ) + + self._token_exp = jwt_payload["exp"] diff --git a/arangoasync/compression.py b/arangoasync/compression.py new file mode 100644 index 0000000..1151149 --- /dev/null +++ b/arangoasync/compression.py @@ -0,0 +1,118 @@ +__all__ = [ + "AcceptEncoding", + "ContentEncoding", + "CompressionManager", + "DefaultCompressionManager", +] + +import zlib +from abc import ABC, abstractmethod +from enum import Enum, auto +from typing import Optional + + +class AcceptEncoding(Enum): + """Valid accepted encodings for the Accept-Encoding header.""" + + DEFLATE = auto() + GZIP = auto() + IDENTITY = auto() + + +class ContentEncoding(Enum): + """Valid content encodings for the Content-Encoding header.""" + + DEFLATE = auto() + GZIP = auto() + + +class CompressionManager(ABC): # pragma: no cover + """Abstract base class for handling request/response compression.""" + + @abstractmethod + def needs_compression(self, data: str | bytes) -> bool: + """Determine if the data needs to be compressed + + Args: + data (str | bytes): Data to check + + Returns: + bool: True if the data needs to be compressed + """ + raise NotImplementedError + + @abstractmethod + def compress(self, data: str | bytes) -> bytes: + """Compress the data + + Args: + data (str | bytes): Data to compress + + Returns: + bytes: Compressed data + """ + raise NotImplementedError + + @abstractmethod + def content_encoding(self) -> str: + """Return the content encoding. + + This is the value of the Content-Encoding header in the HTTP request. + Must match the encoding used in the compress method. + + Returns: + str: Content encoding + """ + raise NotImplementedError + + @abstractmethod + def accept_encoding(self) -> str | None: + """Return the accept encoding. + + This is the value of the Accept-Encoding header in the HTTP request. + Currently, only deflate and "gzip" are supported. + + Returns: + str: Accept encoding + """ + raise NotImplementedError + + +class DefaultCompressionManager(CompressionManager): + """Compress requests using the deflate algorithm. + + Args: + threshold (int): Will compress requests to the server if + the size of the request body (in bytes) is at least the value of this option. + Setting it to -1 will disable request compression (default). + level (int): Compression level. Defaults to 6. + accept (str | None): Accepted encoding. By default, there is + no compression of responses. + """ + + def __init__( + self, + threshold: int = -1, + level: int = 6, + accept: Optional[AcceptEncoding] = None, + ) -> None: + self._threshold = threshold + self._level = level + self._content_encoding = ContentEncoding.DEFLATE.name.lower() + self._accept_encoding = accept.name.lower() if accept else None + + def needs_compression(self, data: str | bytes) -> bool: + return self._threshold != -1 and len(data) >= self._threshold + + def compress(self, data: str | bytes) -> bytes: + if data is not None: + if isinstance(data, bytes): + return zlib.compress(data, self._level) + return zlib.compress(data.encode("utf-8"), self._level) + return b"" + + def content_encoding(self) -> str: + return self._content_encoding + + def accept_encoding(self) -> str | None: + return self._accept_encoding diff --git a/arangoasync/connection.py b/arangoasync/connection.py new file mode 100644 index 0000000..bf6ef8a --- /dev/null +++ b/arangoasync/connection.py @@ -0,0 +1,174 @@ +__all__ = [ + "BaseConnection", + "BasicConnection", +] + +from abc import ABC, abstractmethod +from typing import Any, List, Optional + +from arangoasync.auth import Auth +from arangoasync.compression import CompressionManager, DefaultCompressionManager +from arangoasync.exceptions import ( + ClientConnectionError, + ConnectionAbortedError, + ServerConnectionError, +) +from arangoasync.http import HTTPClient +from arangoasync.request import Method, Request +from arangoasync.resolver import HostResolver +from arangoasync.response import Response + + +class BaseConnection(ABC): + """Blueprint for connection to a specific ArangoDB database. + + Args: + sessions (list): List of client sessions. + host_resolver (HostResolver): Host resolver. + http_client (HTTPClient): HTTP client. + db_name (str): Database name. + compression (CompressionManager | None): Compression manager. + """ + + def __init__( + self, + sessions: List[Any], + host_resolver: HostResolver, + http_client: HTTPClient, + db_name: str, + compression: Optional[CompressionManager] = None, + ) -> None: + self._sessions = sessions + self._db_endpoint = f"/_db/{db_name}" + self._host_resolver = host_resolver + self._http_client = http_client + self._db_name = db_name + self._compression = compression or DefaultCompressionManager() + + @property + def db_name(self) -> str: + """Return the database name.""" + return self._db_name + + def prep_response(self, request: Request, resp: Response) -> Response: + """Prepare response for return. + + Args: + request (Request): Request object. + resp (Response): Response object. + + Returns: + Response: Response object + + Raises: + ServerConnectionError: If the response status code is not successful. + """ + resp.is_success = 200 <= resp.status_code < 300 + if not resp.is_success: + raise ServerConnectionError(resp, request) + return resp + + async def process_request(self, request: Request) -> Response: + """Process request, potentially trying multiple hosts. + + Args: + request (Request): Request object. + + Returns: + Response: Response object. + + Raises: + ConnectionAbortedError: If can't connect to host(s) within limit. + """ + + ex_host_index = -1 + host_index = self._host_resolver.get_host_index() + for tries in range(self._host_resolver.max_tries): + try: + resp = await self._http_client.send_request( + self._sessions[host_index], request + ) + return self.prep_response(request, resp) + except ClientConnectionError: + ex_host_index = host_index + host_index = self._host_resolver.get_host_index() + if ex_host_index == host_index: + self._host_resolver.change_host() + host_index = self._host_resolver.get_host_index() + + raise ConnectionAbortedError( + f"Can't connect to host(s) within limit ({self._host_resolver.max_tries})" + ) + + async def ping(self) -> int: + """Ping host to check if connection is established. + + Returns: + int: Response status code. + + Raises: + ServerConnectionError: If the response status code is not successful. + """ + request = Request(method=Method.GET, endpoint="/_api/collection") + resp = await self.send_request(request) + if resp.status_code in {401, 403}: + raise ServerConnectionError(resp, request, "Authentication failed.") + if not resp.is_success: + raise ServerConnectionError(resp, request, "Bad server response.") + return resp.status_code + + @abstractmethod + async def send_request(self, request: Request) -> Response: # pragma: no cover + """Send an HTTP request to the ArangoDB server. + + Args: + request (Request): HTTP request. + + Returns: + Response: HTTP response. + """ + raise NotImplementedError + + +class BasicConnection(BaseConnection): + """Connection to a specific ArangoDB database. + + Allows for basic authentication to be used (username and password). + + Args: + sessions (list): List of client sessions. + host_resolver (HostResolver): Host resolver. + http_client (HTTPClient): HTTP client. + db_name (str): Database name. + compression (CompressionManager | None): Compression manager. + auth (Auth | None): Authentication information. + """ + + def __init__( + self, + sessions: List[Any], + host_resolver: HostResolver, + http_client: HTTPClient, + db_name: str, + compression: Optional[CompressionManager] = None, + auth: Optional[Auth] = None, + ) -> None: + super().__init__(sessions, host_resolver, http_client, db_name, compression) + self._auth = auth + + async def send_request(self, request: Request) -> Response: + """Send an HTTP request to the ArangoDB server.""" + if request.data is not None and self._compression.needs_compression( + request.data + ): + request.data = self._compression.compress(request.data) + request.headers["content-encoding"] = self._compression.content_encoding() + + accept_encoding: str | None = self._compression.accept_encoding() + if accept_encoding is not None: + request.headers["accept-encoding"] = accept_encoding + + if self._auth: + request.auth = self._auth + + return await self.process_request(request) diff --git a/arangoasync/exceptions.py b/arangoasync/exceptions.py new file mode 100644 index 0000000..2275f1b --- /dev/null +++ b/arangoasync/exceptions.py @@ -0,0 +1,84 @@ +from typing import Optional + +from arangoasync.request import Request +from arangoasync.response import Response + + +class ArangoError(Exception): + """Base class for all exceptions in python-arango-async.""" + + +class ArangoClientError(ArangoError): + """Base class for all client-related exceptions. + + Args: + msg (str): Error message. + + Attributes: + source (str): Source of the error (always set to "client") + message (str): Error message. + """ + + source = "client" + + def __init__(self, msg: str) -> None: + super().__init__(msg) + self.message = msg + + +class ArangoServerError(ArangoError): + """Base class for all server-related exceptions. + + Args: + resp (Response): HTTP response object. + request (Request): HTTP request object. + msg (str | None): Error message. + + Attributes: + source (str): Source of the error (always set to "server") + message (str): Error message. + url (str): URL of the request. + response (Response): HTTP response object. + request (Request): HTTP request object. + http_method (str): HTTP method of the request. + http_code (int): HTTP status code of the response. + http_headers (dict): HTTP headers of the response. + """ + + source = "server" + + def __init__( + self, resp: Response, request: Request, msg: Optional[str] = None + ) -> None: + msg = msg or resp.error_message or resp.status_text + self.error_message = resp.error_message + self.error_code = resp.error_code + if self.error_code is not None: + msg = f"[HTTP {resp.status_code}][ERR {self.error_code}] {msg}" + else: + msg = f"[HTTP {resp.status_code}] {msg}" + self.error_code = resp.status_code + super().__init__(msg) + self.message = msg + self.url = resp.url + self.response = resp + self.request = request + self.http_method = resp.method.name + self.http_code = resp.status_code + self.http_headers = resp.headers + + +class ConnectionAbortedError(ArangoClientError): + """The connection was aborted.""" + + +class ClientConnectionError(ArangoClientError): + """The request was unable to reach the server.""" + + +class JWTExpiredError(ArangoClientError): + """JWT token has expired.""" + + +class ServerConnectionError(ArangoServerError): + """Failed to connect to ArangoDB server.""" diff --git a/arangoasync/http.py b/arangoasync/http.py index b6ebfc0..e80dc91 100644 --- a/arangoasync/http.py +++ b/arangoasync/http.py @@ -7,8 +7,16 @@ from abc import ABC, abstractmethod from typing import Any, Optional -from aiohttp import BaseConnector, BasicAuth, ClientSession, ClientTimeout, TCPConnector - +from aiohttp import ( + BaseConnector, + BasicAuth, + ClientSession, + ClientTimeout, + TCPConnector, + client_exceptions, +) + +from arangoasync.exceptions import ClientConnectionError from arangoasync.request import Request from arangoasync.response import Response @@ -74,10 +82,6 @@ class AioHTTPClient(HTTPClient): timeout (aiohttp.ClientTimeout | None): Client timeout settings. 300s total timeout by default for a complete request/response operation. read_bufsize (int): Size of read buffer (64KB default). - auth (aiohttp.BasicAuth | None): HTTP authentication helper. - Should be used for specifying authorization data in client API. - compression_threshold (int): Will compress requests to the server if the size - of the request body (in bytes) is at least the value of this option. .. _aiohttp: https://docs.aiohttp.org/en/stable/ @@ -88,8 +92,6 @@ def __init__( connector: Optional[BaseConnector] = None, timeout: Optional[ClientTimeout] = None, read_bufsize: int = 2**16, - auth: Optional[BasicAuth] = None, - compression_threshold: int = 1024, ) -> None: self._connector = connector or TCPConnector( keepalive_timeout=60, # timeout for connection reusing after releasing @@ -100,8 +102,6 @@ def __init__( connect=60, # max number of seconds for acquiring a pool connection ) self._read_bufsize = read_bufsize - self._auth = auth - self._compression_threshold = compression_threshold def create_session(self, host: str) -> ClientSession: """Return a new session given the base host URL. @@ -117,7 +117,6 @@ def create_session(self, host: str) -> ClientSession: base_url=host, connector=self._connector, timeout=self._timeout, - auth=self._auth, read_bufsize=self._read_bufsize, ) @@ -134,31 +133,40 @@ async def send_request( Returns: Response: HTTP response. + + Raises: + ClientConnectionError: If the request fails. """ - method = request.method - endpoint = request.endpoint - headers = request.headers - params = request.params - data = request.data - compress = data is not None and len(data) >= self._compression_threshold - - async with session.request( - method.name, - endpoint, - headers=headers, - params=params, - data=data, - compress=compress, - ) as response: - raw_body = await response.read() - return Response( - method=method, - url=str(response.real_url), - headers=response.headers, - status_code=response.status, - status_text=str(response.reason), - raw_body=raw_body, + + if request.auth is not None: + auth = BasicAuth( + login=request.auth.username, + password=request.auth.password, + encoding=request.auth.encoding, ) + else: + auth = None + + try: + async with session.request( + request.method.name, + request.endpoint, + headers=request.headers, + params=request.params, + data=request.data, + auth=auth, + ) as response: + raw_body = await response.read() + return Response( + method=request.method, + url=str(response.real_url), + headers=response.headers, + status_code=response.status, + status_text=str(response.reason), + raw_body=raw_body, + ) + except client_exceptions.ClientConnectionError as e: + raise ClientConnectionError(str(e)) from e DefaultHTTPClient = AioHTTPClient diff --git a/arangoasync/request.py b/arangoasync/request.py index 5971f92..0c183d5 100644 --- a/arangoasync/request.py +++ b/arangoasync/request.py @@ -6,6 +6,7 @@ from enum import Enum, auto from typing import Optional +from arangoasync.auth import Auth from arangoasync.typings import Params, RequestHeaders from arangoasync.version import __version__ @@ -30,14 +31,16 @@ class Request: endpoint (str): API endpoint. headers (dict | None): Request headers. params (dict | None): URL parameters. - data (str | None): Request payload. + data (bytes | None): Request payload. + auth (Auth | None): Authentication. Attributes: method (Method): HTTP method. endpoint (str): API endpoint. headers (dict | None): Request headers. params (dict | None): URL parameters. - data (str | None): Request payload. + data (bytes | None): Request payload. + auth (Auth | None): Authentication. """ __slots__ = ( @@ -46,6 +49,7 @@ class Request: "headers", "params", "data", + "auth", ) def __init__( @@ -54,13 +58,15 @@ def __init__( endpoint: str, headers: Optional[RequestHeaders] = None, params: Optional[Params] = None, - data: Optional[str] = None, + data: Optional[bytes] = None, + auth: Optional[Auth] = None, ) -> None: self.method: Method = method self.endpoint: str = endpoint self.headers: RequestHeaders = self._normalize_headers(headers) self.params: Params = self._normalize_params(params) - self.data: Optional[str] = data + self.data: Optional[bytes] = data + self.auth: Optional[Auth] = auth @staticmethod def _normalize_headers(headers: Optional[RequestHeaders]) -> RequestHeaders: diff --git a/arangoasync/resolver.py b/arangoasync/resolver.py new file mode 100644 index 0000000..1aa2bd8 --- /dev/null +++ b/arangoasync/resolver.py @@ -0,0 +1,115 @@ +__all__ = [ + "HostResolver", + "SingleHostResolver", + "RoundRobinHostResolver", + "DefaultHostResolver", + "get_resolver", +] + +from abc import ABC, abstractmethod +from typing import List, Optional + + +class HostResolver(ABC): + """Abstract base class for host resolvers. + + Args: + host_count (int): Number of hosts. + max_tries (int): Maximum number of attempts to try a host. + + Raises: + ValueError: If max_tries is less than host_count. + """ + + def __init__(self, host_count: int = 1, max_tries: Optional[int] = None) -> None: + max_tries = max_tries or host_count * 3 + if max_tries < host_count: + raise ValueError( + "The maximum number of attempts cannot be " + "lower than the number of hosts." + ) + self._host_count = host_count + self._max_tries = max_tries + self._index = 0 + + @abstractmethod + def get_host_index(self) -> int: # pragma: no cover + """Return the index of the host to use. + + Returns: + int: Index of the host. + """ + raise NotImplementedError + + def change_host(self) -> None: + """If there aer multiple hosts available, switch to the next one.""" + self._index = (self._index + 1) % self.host_count + + @property + def host_count(self) -> int: + """Return the number of hosts.""" + return self._host_count + + @property + def max_tries(self) -> int: + """Return the maximum number of attempts.""" + return self._max_tries + + +class SingleHostResolver(HostResolver): + """Single host resolver. Always returns the same host index.""" + + def __init__(self, host_count: int, max_tries: Optional[int] = None) -> None: + super().__init__(host_count, max_tries) + + def get_host_index(self) -> int: + return self._index + + +class RoundRobinHostResolver(HostResolver): + """Round-robin host resolver. Changes host every time. + + Useful for bulk inserts or updates. + + Note: + Do not use this resolver for stream transactions. + Transaction IDs cannot be shared across different coordinators. + """ + + def __init__(self, host_count: int, max_tries: Optional[int] = None) -> None: + super().__init__(host_count, max_tries) + self._index = -1 + + def get_host_index(self, indexes_to_filter: Optional[List[int]] = None) -> int: + self.change_host() + return self._index + + +DefaultHostResolver = SingleHostResolver + + +def get_resolver( + strategy: str, + host_count: int, + max_tries: Optional[int] = None, +) -> HostResolver: + """Return a host resolver based on the strategy. + + Args: + strategy (str): Resolver strategy. + host_count (int): Number of hosts. + max_tries (int): Maximum number of attempts to try a host. + + Returns: + HostResolver: Host resolver. + + Raises: + ValueError: If the strategy is not supported. + """ + if strategy == "roundrobin": + return RoundRobinHostResolver(host_count, max_tries) + if strategy == "single": + return SingleHostResolver(host_count, max_tries) + if strategy == "default": + return DefaultHostResolver(host_count, max_tries) + raise ValueError(f"Unsupported host resolver strategy: {strategy}") diff --git a/docs/conf.py b/docs/conf.py index 78b9956..6dae081 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -24,6 +24,7 @@ intersphinx_mapping = { "aiohttp": ("https://docs.aiohttp.org/en/stable/", None), + "jwt": ("https://pyjwt.readthedocs.io/en/stable/", None), } napoleon_google_docstring = True diff --git a/docs/specs.rst b/docs/specs.rst index eb39e72..29ba812 100644 --- a/docs/specs.rst +++ b/docs/specs.rst @@ -4,11 +4,23 @@ API Specification This page contains the specification for all classes and methods available in python-arango-async. +.. automodule:: arangoasync.auth + :members: + +.. automodule:: arangoasync.connection + :members: + +.. automodule:: arangoasync.exceptions + :members: ArangoError, ArangoClientError + .. automodule:: arangoasync.http :members: .. automodule:: arangoasync.request :members: +.. automodule:: arangoasync.resolver + :members: + .. automodule:: arangoasync.response :members: diff --git a/pyproject.toml b/pyproject.toml index cd09627..9d6b7b5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,7 @@ dependencies = [ "setuptools>=42", "aiohttp>=3.9", "multidict>=6.0", + "PyJWT>=2.8.0", ] [tool.setuptools.dynamic] diff --git a/tests/conftest.py b/tests/conftest.py index 9edb2a3..f335bed 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,6 +8,7 @@ class GlobalData: url: str = None root: str = None password: str = None + sys_db_name: str = "_system" global_data = GlobalData() @@ -51,3 +52,8 @@ def root(): @pytest.fixture(autouse=False) def password(): return global_data.password + + +@pytest.fixture(autouse=False) +def sys_db_name(): + return global_data.sys_db_name diff --git a/tests/test_compression.py b/tests/test_compression.py new file mode 100644 index 0000000..26e9baa --- /dev/null +++ b/tests/test_compression.py @@ -0,0 +1,20 @@ +from arangoasync.compression import AcceptEncoding, DefaultCompressionManager + + +def test_DefaultCompressionManager_no_compression(): + manager = DefaultCompressionManager() + assert not manager.needs_compression("test") + assert not manager.needs_compression(b"test") + manager = DefaultCompressionManager(threshold=10) + assert not manager.needs_compression("test") + + +def test_DefaultCompressionManager_compress(): + manager = DefaultCompressionManager( + threshold=1, level=9, accept=AcceptEncoding.DEFLATE + ) + data = "a" * 10 + "b" * 10 + assert manager.needs_compression(data) + assert len(manager.compress(data)) < len(data) + assert manager.content_encoding() == "deflate" + assert manager.accept_encoding() == "deflate" diff --git a/tests/test_connection.py b/tests/test_connection.py new file mode 100644 index 0000000..6c98081 --- /dev/null +++ b/tests/test_connection.py @@ -0,0 +1,44 @@ +import pytest + +from arangoasync.auth import Auth +from arangoasync.connection import BasicConnection +from arangoasync.exceptions import ServerConnectionError +from arangoasync.http import AioHTTPClient +from arangoasync.resolver import DefaultHostResolver + + +@pytest.mark.asyncio +async def test_BasicConnection_ping_failed(url, sys_db_name): + client = AioHTTPClient() + session = client.create_session(url) + resolver = DefaultHostResolver(1) + + connection = BasicConnection( + sessions=[session], + host_resolver=resolver, + http_client=client, + db_name=sys_db_name, + ) + + with pytest.raises(ServerConnectionError): + await connection.ping() + await session.close() + + +@pytest.mark.asyncio +async def test_BasicConnection_ping_success(url, sys_db_name, root, password): + client = AioHTTPClient() + session = client.create_session(url) + resolver = DefaultHostResolver(1) + + connection = BasicConnection( + sessions=[session], + host_resolver=resolver, + http_client=client, + db_name=sys_db_name, + auth=Auth(username=root, password=password), + ) + + status_code = await connection.ping() + assert status_code == 200 + await session.close() diff --git a/tests/test_http.py b/tests/test_http.py index a1047dc..e631586 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -1,10 +1,30 @@ import pytest -from aiohttp import BasicAuth -from arangoasync.http import AioHTTPClient +from arangoasync.auth import Auth +from arangoasync.exceptions import ClientConnectionError +from arangoasync.http import AioHTTPClient, DefaultHTTPClient from arangoasync.request import Method, Request +def test_DefaultHTTPClient(): + # This test is here in case to prevent accidental changes to the DefaultHTTPClient. + # Changed should be pushed only after the new HTTP client is covered by tests. + assert DefaultHTTPClient == AioHTTPClient + + +@pytest.mark.asyncio +async def test_AioHTTPClient_wrong_url(): + client = AioHTTPClient() + session = client.create_session("http://www.fasdfdsafadawe3523523532plmcom.tgzs") + request = Request( + method=Method.GET, + endpoint="/_api/version", + ) + with pytest.raises(ClientConnectionError): + await client.send_request(session, request) + await session.close() + + @pytest.mark.asyncio async def test_AioHTTPClient_simple_request(url): client = AioHTTPClient() @@ -18,18 +38,21 @@ async def test_AioHTTPClient_simple_request(url): assert response.url == f"{url}/_api/version" assert response.status_code == 401 assert response.status_text == "Unauthorized" + await session.close() @pytest.mark.asyncio async def test_AioHTTPClient_auth_pass(url, root, password): - client = AioHTTPClient(auth=BasicAuth(root, password)) + client = AioHTTPClient() session = client.create_session(url) request = Request( method=Method.GET, endpoint="/_api/version", + auth=Auth(username=root, password=password), ) response = await client.send_request(session, request) assert response.method == Method.GET assert response.url == f"{url}/_api/version" assert response.status_code == 200 assert response.status_text == "OK" + await session.close() diff --git a/tests/test_resolver.py b/tests/test_resolver.py new file mode 100644 index 0000000..5d53d72 --- /dev/null +++ b/tests/test_resolver.py @@ -0,0 +1,56 @@ +import pytest + +from arangoasync.resolver import ( + DefaultHostResolver, + RoundRobinHostResolver, + SingleHostResolver, + get_resolver, +) + + +def test_get_resolver(): + resolver = get_resolver("default", 1, 2) + assert isinstance(resolver, DefaultHostResolver) + + resolver = get_resolver("single", 2) + assert isinstance(resolver, SingleHostResolver) + + resolver = get_resolver("roundrobin", 3) + assert isinstance(resolver, RoundRobinHostResolver) + + with pytest.raises(ValueError): + get_resolver("invalid", 1) + + with pytest.raises(ValueError): + # max_tries cannot be less than host_count + get_resolver("roundrobin", 3, 1) + + +def test_SingleHostResolver(): + resolver = SingleHostResolver(1, 2) + assert resolver.host_count == 1 + assert resolver.max_tries == 2 + assert resolver.get_host_index() == 0 + assert resolver.get_host_index() == 0 + + resolver = SingleHostResolver(3) + assert resolver.host_count == 3 + assert resolver.max_tries == 9 + assert resolver.get_host_index() == 0 + resolver.change_host() + assert resolver.get_host_index() == 1 + resolver.change_host() + assert resolver.get_host_index() == 2 + resolver.change_host() + assert resolver.get_host_index() == 0 + + +def test_RoundRobinHostResolver(): + resolver = RoundRobinHostResolver(3) + assert resolver.host_count == 3 + assert resolver.get_host_index() == 0 + assert resolver.get_host_index() == 1 + assert resolver.get_host_index() == 2 + assert resolver.get_host_index() == 0 + resolver.change_host() + assert resolver.get_host_index() == 2