import base64
import datetime
import functools
import json
import re
import time
import warnings
from types import MethodType
from typing import Any, Callable, List, Optional, Tuple, Type, cast
from urllib.parse import urljoin
import requests
from requests.exceptions import ConnectionError, HTTPError
from ..exc import (
BadResponseError,
EndpointNotFoundError,
GatewayError,
InvalidCredentialsError,
LimitExceededError,
NotAuthenticatedError,
VersionMismatchError,
map_http_error,
)
from .options import APIClientOptions
# TODO: Protocol typing is available from Python 3.8
class UsesAPI:
"""
Typing for classes that are bound with APIClient,
set on self.api attribute
"""
api: "APIClient"
class JWTAuthToken:
def __init__(self, value: str) -> None:
self.value: str = value
try:
header, payload, signature = value.split(".")
self.header = json.loads(base64.b64decode(header + "=="))
self.payload = json.loads(base64.b64decode(payload + "=="))
except ValueError:
raise InvalidCredentialsError(
"Invalid authentication token. Verify whether actual token is provided "
"instead of its UUID."
)
@property
def is_expired(self) -> bool:
if "exp" not in self.header:
# No expiration time set
return False
return bool(datetime.datetime.utcnow().timestamp() >= self.header["exp"])
@property
def username(self) -> str:
return str(self.payload["login"])
[docs]class APIClient:
"""
Client for MWDB REST API that performs authentication and
low-level API request/response handling.
If you want to send arbitrary request to MWDB API, use :py:meth:`get`,
:py:meth:`post`, :py:meth:`put` and :py:meth:`delete` methods using
``MWDB.api`` property.
.. code-block:: python
mwdb = MWDB()
...
# Deletes object with given sha256
mwdb.api.delete(f'object/{sha256}')
"""
def __init__(
self,
_auth_token: Optional[str] = None,
autologin: bool = True,
**api_options: Any,
) -> None:
self.options: APIClientOptions = APIClientOptions(**api_options)
self.auth_token: Optional[JWTAuthToken] = None
# These state variables will be filled after
# successful authentication
self.username: Optional[str] = None
self.password: Optional[str] = None
self.api_key: Optional[str] = None
self._server_metadata: Optional[dict] = None
self.session: requests.Session = requests.Session()
from ..__version__ import __version__
self.session.headers[
"User-Agent"
] = f'mwdblib/{__version__} {str(self.session.headers["User-Agent"])}'
if _auth_token:
self.set_auth_token(_auth_token)
if autologin:
if self.options.api_key:
self.set_api_key(self.options.api_key)
elif self.options.username and self.options.password:
self.login(self.options.username, self.options.password)
@property
def server_metadata(self) -> dict:
"""
Information about MWDB Core server from ``/api/server`` endpoint.
"""
if self._server_metadata is None:
self._server_metadata = self.get("server", noauth=True)
return self._server_metadata
@property
def server_version(self) -> str:
"""
MWDB Core server version
"""
return str(self.server_metadata["server_version"])
@property
def logged_user(self) -> Optional[str]:
"""
Username of logged-in user or the owner of used API key.
Returns None if no credentials are provided
"""
return self.auth_token.username if self.auth_token else None
[docs] def supports_version(self, required_version: str) -> bool:
"""
Checks if server version is higher or equal than provided.
.. versionadded:: 4.1.0
"""
return parse_version(self.server_version) >= parse_version(required_version)
def set_auth_token(self, auth_key: str) -> None:
self.auth_token = JWTAuthToken(auth_key)
self.session.headers.update(
{"Authorization": f"Bearer {self.auth_token.value}"}
)
[docs] def login(self, username: str, password: str) -> None:
"""
Performs authentication using provided credentials
.. note:
It's not recommended to use this method directly.
Use :py:meth:`MWDB.login` instead.
:param username: Account username
:param password: Account password
"""
token = self.post(
"auth/login", json={"login": username, "password": password}, noauth=True
)["token"]
self.set_auth_token(token)
# Store credentials in API state
self.username = username
self.password = password
[docs] def set_api_key(self, api_key: str) -> None:
"""
Sets API key to be used for authorization
.. note:
It's not recommended to use this method directly.
Pass ``api_key`` argument to ``MWDB`` constructor.
:param api_key: API key to set
"""
self.set_auth_token(api_key)
# Store credentials in API state
self.api_key = api_key
if self.auth_token is not None:
self.username = self.auth_token.username
[docs] def logout(self) -> None:
"""
Removes authorization token from APIClient instance
"""
self.auth_token = None
self.session.headers.pop("Authorization")
def perform_request(
self, method: str, url: str, *args: Any, **kwargs: Any
) -> requests.models.Response:
try:
response = self.session.request(method, url, *args, **kwargs)
response.raise_for_status()
return response
except requests.exceptions.HTTPError as http_error:
mapped_error = map_http_error(http_error)
if mapped_error is None:
raise
raise mapped_error
[docs] def request(
self,
method: str,
url: str,
noauth: bool = False,
raw: bool = False,
*args: Any,
**kwargs: Any,
) -> Any:
"""
Sends request to MWDB API. This method can be used for accessing features that
are not directly supported by mwdblib library.
Other keyword arguments are the same as in requests library.
.. seealso::
Use functions specific for HTTP methods instead of passing
``method`` argument on your own:
- :py:meth:`APIClient.get`
- :py:meth:`APIClient.post`
- :py:meth:`APIClient.put`
- :py:meth:`APIClient.delete`
:param method: HTTP method
:param url: Relative url of API endpoint
:param noauth: |
Don't check if user is authenticated before sending request (default: False)
:param raw: Return raw response bytes instead of parsed JSON (default: False)
"""
# Check if authenticated
if not noauth and self.auth_token is None:
raise NotAuthenticatedError(
"API credentials for MWDB were not set, pass api_key parameter "
"to MWDB or call MWDB.login first"
)
# Set method name and request URL
url = urljoin(self.options.api_url, url)
# Pass verify_ssl setting to requests kwargs
kwargs["verify"] = self.options.verify_ssl
downtime_retries = self.options.max_downtime_retries
downtime_timeout = self.options.downtime_timeout
retry_on_downtime = self.options.retry_on_downtime
retry_idempotent = self.options.retry_idempotent
while True:
try:
response = self.perform_request(method, url, *args, **kwargs)
try:
return response.json() if not raw else response.content
except ValueError:
raise BadResponseError(
"Can't decode JSON response from server. "
"Probably APIClient.api_url points to the MWDB web app "
"instead of MWDB REST API."
)
except NotAuthenticatedError:
# Forget current auth_key
self.logout()
# If no password set: re-raise
if self.username is None or self.password is None:
raise
# Try to log in
self.login(self.username, self.password)
# Retry failed request...
except LimitExceededError as e:
if not self.options.obey_ratelimiter:
raise
# Assume that http_error is not Optional for LimitExceededError
http_error: HTTPError = cast(HTTPError, e.http_error)
if "Retry-After" not in http_error.response.headers:
# This should be exponential backoff, but we don't expect
# to see mwdb instances without retry-after headers anyway
retry_after = 60
else:
retry_after = int(http_error.response.headers["Retry-After"])
if self.options.emit_warnings:
warnings.warn(
f"Rate limit exceeded. Sleeping for a {retry_after} seconds."
)
time.sleep(retry_after)
# Retry failed request...
except (ConnectionError, GatewayError):
if (
not retry_on_downtime
or downtime_retries == 0
or (not retry_idempotent and method == "post")
):
raise
downtime_retries -= 1
if self.options.emit_warnings:
warnings.warn(
"Retrying request due to connectivity issues. "
f"Sleeping for {downtime_timeout} seconds."
)
time.sleep(downtime_timeout)
# Retry failed request...
def get(self, *args: Any, **kwargs: Any) -> Any:
return self.request("get", *args, **kwargs)
def post(self, *args: Any, **kwargs: Any) -> Any:
return self.request("post", *args, **kwargs)
def put(self, *args: Any, **kwargs: Any) -> Any:
return self.request("put", *args, **kwargs)
def delete(self, *args: Any, **kwargs: Any) -> Any:
return self.request("delete", *args, **kwargs)
[docs] @staticmethod
def requires(required_version: str, always_check_version: bool = False) -> Callable:
"""
Method decorator that provides server version requirement
and fallback to older implementation if available.
To optimize requests sent by CLI: first method is called
always if server version is not already available. If it
fails with EndpointNotFoundError, server version is fetched
and used to determine if fallback is available.
If your method fails on something different than missing endpoint,
you can check version always by enabling ``always_check_version``
flag.
"""
class VersionDependentMethod:
def __init__(self, main_method: Callable) -> None:
self.required_version = parse_version(required_version)
self.main_method = main_method
self.fallback_methods: List[Tuple[Tuple[int, ...], Callable]] = []
functools.update_wrapper(self, self.main_method)
def fallback(self, fallback_version: str) -> Callable:
"""
Registers fallback method
"""
def fallback_decorator(fallback_method: Callable) -> Callable:
self.fallback_methods = sorted(
self.fallback_methods
+ [(parse_version(fallback_version), fallback_method)],
key=lambda e: e[0],
reverse=True,
)
# Fallback method is still regular method
# that can be called directly
return fallback_method
return fallback_decorator
def __get__(self, obj: UsesAPI, objtype: Type = None) -> Callable:
# See also:
# https://docs.python.org/3/howto/descriptor.html#functions-and-methods
if obj is None:
# If descriptor got from class and not from instance
return self
return MethodType(self, obj)
def __call__(self, obj: UsesAPI, *args: Any, **kwargs: Any) -> Any:
# Check server version if available
server_version: Optional[Tuple[int, ...]]
if obj.api._server_metadata is not None or always_check_version:
server_version = parse_version(obj.api.server_version)
else:
server_version = None
# If version available and matches the requirement: call main method
if server_version is None or server_version >= self.required_version:
try:
return self.main_method(obj, *args, **kwargs)
except EndpointNotFoundError:
# On EndpointNotFoundError, check version requirement
server_version = parse_version(obj.api.server_version)
if server_version >= self.required_version:
# If matches the requirement: something wrong happened
raise
# If not: call fallback methods
fallback_version = self.required_version
for fallback_version, fallback_method in self.fallback_methods:
if server_version >= fallback_version:
return fallback_method(obj, *args, **kwargs)
else:
# If no fallback matched: raise VersionMismatchError
raise VersionMismatchError(
f"Required server version is "
f"'>={'.'.join(map(str, fallback_version))}' "
f"but '{obj.api.server_version}' is installed"
)
return VersionDependentMethod
def parse_version(version: str) -> Tuple[int, ...]:
"""
Parses server version into numerical tuple. Pre-release suffixes
e.g. -rc1 are accepted but ignored.
"""
match = re.match(
r"^(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:[-.]\w+)?",
version,
)
if not match:
raise ValueError(f"Version doesn't match to pattern: {version}")
return tuple(int(value) for value in match.groups())