Source code for airfs.storage.oss

"""Alibaba cloud OSS."""
from contextlib import contextmanager as _contextmanager
import re as _re

import oss2 as _oss  # type: ignore
from oss2.models import PartInfo as _PartInfo  # type: ignore
from oss2.exceptions import OssError as _OssError  # type: ignore

from airfs._core.io_base import memoizedmethod as _memoizedmethod
from airfs._core.exceptions import (
    ObjectNotFoundError as _ObjectNotFoundError,
    ObjectPermissionError as _ObjectPermissionError,
    ObjectNotASymlinkError as _ObjectNotASymlinkError,
    ObjectNotImplementedError as _ObjectNotImplementedError,
)
from airfs.io import (
    ObjectRawIOBase as _ObjectRawIOBase,
    ObjectBufferedIOBase as _ObjectBufferedIOBase,
    SystemBase as _SystemBase,
)

# OSS error status -> airfs exception class raised in its place by
# "_handle_oss_error". Statuses that are not listed here are re-raised unchanged.
_ERROR_CODES = {
    403: _ObjectPermissionError,
    404: _ObjectNotFoundError,
    409: _ObjectPermissionError,
}


@_contextmanager
def _handle_oss_error():
    """Convert OSS exceptions raised in the managed block to IO exceptions.

    An OSS error whose "status" is listed in "_ERROR_CODES" is replaced by the
    matching exception class, instantiated with the "Message" entry of the OSS
    error details (empty string if absent). The OSS error is attached as the
    explicit cause of the new exception. Any other OSS error is propagated
    unchanged.

    Raises:
        OSError subclasses: IO error.
    """
    try:
        yield

    except _OssError as exception:
        error_class = _ERROR_CODES.get(exception.status)
        if error_class is None:
            # Unmapped status: propagate the original OSS error as is.
            raise
        raise error_class(exception.details.get("Message", "")) from exception


class _OSSSystem(_SystemBase):
    """OSS system.

    Args:
        storage_parameters (dict): OSS2 Auth keyword arguments and endpoint.
            This is generally OSS credentials and configuration.
            The "endpoint" entry is mandatory.
        unsecure (bool): If True, disables TLS/SSL to improve transfer performance.
            But makes connection unsecure.
    """

    __slots__ = ("_unsecure", "_endpoint")

    SUPPORTS_SYMLINKS = True

    _CTIME_KEYS = ("Creation-Date", "creation_date")
    _MTIME_KEYS = ("Last-Modified", "last_modified")

    def __init__(self, storage_parameters=None, *args, **kwargs):
        """Init.

        Args:
            storage_parameters (dict): OSS2 Auth keyword arguments and endpoint.
                The "endpoint" entry is extracted here; the remaining entries are
                forwarded to the base class.

        Raises:
            ValueError: "storage_parameters" is missing or has no "endpoint".
        """
        try:
            # Work on a copy so the caller's dict keeps its "endpoint" entry.
            storage_parameters = storage_parameters.copy()
            self._endpoint = storage_parameters.pop("endpoint")
        except (AttributeError, KeyError):
            raise ValueError('"endpoint" is required as "storage_parameters"')

        _SystemBase.__init__(
            self, storage_parameters=storage_parameters, *args, **kwargs
        )
        if self._unsecure:
            self._endpoint = self._endpoint.replace("https://", "http://")

    def copy(self, src, dst, other_system=None):
        """Copy an object of the same storage.

        The copy is requested on the destination bucket, with the source given as
        bucket name and key.

        Args:
            src (str): Path or URL.
            dst (str): Path or URL.
            other_system (airfs._core.io_system.SystemBase subclass): Unused.
        """
        copy_source = self.get_client_kwargs(src)
        copy_destination = self.get_client_kwargs(dst)
        with _handle_oss_error():
            bucket = self._get_bucket(copy_destination)
            bucket.copy_object(
                source_bucket_name=copy_source["bucket_name"],
                source_key=copy_source["key"],
                target_key=copy_destination["key"],
            )

    def get_client_kwargs(self, path):
        """Get base keyword arguments for the client for a specific path.

        Args:
            path (str): Absolute path or URL.

        Returns:
            dict: client args. Always contains "bucket_name"; contains "key" only
                when the path has a non-empty key part.
        """
        bucket_name, key = self.split_locator(path)
        kwargs = dict(bucket_name=bucket_name)
        if key:
            kwargs["key"] = key
        return kwargs

    def _get_client(self):
        """OSS2 Auth client.

        The auth class is selected from the storage parameters: "StsAuth" when a
        "security_token" is provided, "Auth" when any other parameter is provided,
        "AnonymousAuth" when there is no parameter at all.

        Returns:
            oss2.Auth or oss2.StsAuth or oss2.AnonymousAuth: client
        """
        return (
            _oss.StsAuth
            if "security_token" in self._storage_parameters
            else _oss.Auth if self._storage_parameters else _oss.AnonymousAuth
        )(**self._storage_parameters)

    def _get_roots(self):
        """Return URL roots for this storage.

        Returns:
            tuple of str or re.Pattern: URL roots
        """
        return (
            # OSS Scheme
            # - oss://<bucket>/<key>
            "oss://",
            # URL (With common aliyuncs.com endpoint):
            # - http://<bucket>.oss-<region>.aliyuncs.com/<key>
            # - https://<bucket>.oss-<region>.aliyuncs.com/<key>
            # Note: "oss-<region>.aliyuncs.com" may be replaced by another endpoint
            #
            # The "replace" escapes every "." of the assembled pattern, including
            # the one that separates the bucket name from the endpoint host.
            _re.compile(
                (r"^https?://[\w-]+.%s" % self._endpoint.split("//", 1)[1]).replace(
                    ".", r"\."
                )
            ),
        )

    def _get_bucket(self, client_kwargs):
        """Get bucket object.

        Args:
            client_kwargs (dict): Client arguments. Only "bucket_name" is used.

        Returns:
            oss2.Bucket
        """
        return _oss.Bucket(
            self.client,
            endpoint=self._endpoint,
            bucket_name=client_kwargs["bucket_name"],
        )

    def islink(self, path=None, client_kwargs=None, header=None):
        """Returns True if the object is a symbolic link.

        Args:
            path (str): File path or URL.
            client_kwargs (dict): Client arguments.
            header (dict): Object header.

        Returns:
            bool: True if the object is Symlink.
        """
        header = self.head(path, client_kwargs, header)

        # The object type is looked up under the first of these keys that exists.
        # NOTE(review): "pop" removes the key from the header mapping; confirm that
        # no caller relies on the entry still being present afterwards.
        for key in ("x-oss-object-type", "type"):
            try:
                return header.pop(key) == "Symlink"
            except KeyError:
                continue
        return False

    def _head(self, client_kwargs):
        """Returns object HTTP header.

        Returns the object header when a "key" is given, the bucket information
        header otherwise.

        Args:
            client_kwargs (dict): Client arguments.

        Returns:
            dict: HTTP header.
        """
        with _handle_oss_error():
            bucket = self._get_bucket(client_kwargs)

            if "key" in client_kwargs:
                return bucket.head_object(key=client_kwargs["key"]).headers

            return bucket.get_bucket_info().headers

    def _make_dir(self, client_kwargs):
        """Make a directory.

        With a "key", an empty object is written at this key. Without "key", the
        bucket itself is created.

        Args:
            client_kwargs (dict): Client arguments.
        """
        with _handle_oss_error():
            bucket = self._get_bucket(client_kwargs)

            if "key" in client_kwargs:
                return bucket.put_object(key=client_kwargs["key"], data=b"")

            return bucket.create_bucket()

    def _remove(self, client_kwargs):
        """Remove an object.

        With a "key", the object is deleted. Without "key", the bucket itself is
        deleted.

        Args:
            client_kwargs (dict): Client arguments.
        """
        with _handle_oss_error():
            bucket = self._get_bucket(client_kwargs)

            if "key" in client_kwargs:
                return bucket.delete_object(key=client_kwargs["key"])

            return bucket.delete_bucket()

    @staticmethod
    def _model_to_dict(model, ignore):
        """Convert OSS model to dict.

        Only the public instance attributes of the model are kept (Names that do
        not start with an underscore).

        Args:
            model (oss2.models.RequestResult): Model.
            ignore (tuple of str): Keys to not insert to dict.

        Returns:
            dict: Model dict version.
        """
        return {
            attr: value
            for attr, value in model.__dict__.items()
            if not attr.startswith("_") and attr not in ignore
        }

    def _list_locators(self, max_results):
        """List locators.

        Args:
            max_results (int): The maximum results that should return the method.
                If not specified (None or 0), 100 is requested.

        Yields:
            tuple: locator name str, locator header dict, has content bool
        """
        with _handle_oss_error():
            response = _oss.Service(self.client, endpoint=self._endpoint).list_buckets(
                max_keys=max_results or 100
            )

        for bucket in response.buckets:
            yield bucket.name, self._model_to_dict(bucket, ("name",)), True

    def _list_objects(self, client_kwargs, path, max_results, first_level):
        """List objects.

        Results are fetched page by page until the storage does not return a
        continuation marker anymore.

        Args:
            client_kwargs (dict): Client arguments.
            path (str): Path to list.
            max_results (int): The maximum results that should return the method.
                Passed as the page size of each listing request.
            first_level (bool): If True, may only first level objects.
                Not used by this implementation.

        Yields:
            tuple: object path str, object header dict, has content bool

        Raises:
            ObjectNotFoundError: A listing request returned no object.
        """
        prefix = self.split_locator(path)[1]
        index = len(prefix)
        kwargs = dict(prefix=prefix)
        if max_results:
            kwargs["max_keys"] = max_results

        bucket = self._get_bucket(client_kwargs)

        while True:
            with _handle_oss_error():
                response = bucket.list_objects(**kwargs)

            if not response.object_list:
                raise _ObjectNotFoundError(path=path)

            for obj in response.object_list:
                # Yielded names are relative to the listed prefix.
                yield obj.key[index:], self._model_to_dict(obj, ("key",)), False

            if response.next_marker:
                # The marker must be part of the arguments sent with the next
                # listing request, otherwise the same first page is returned
                # again and the loop never ends.
                kwargs["marker"] = response.next_marker
            else:
                break

    def read_link(self, path=None, client_kwargs=None, header=None):
        """Return the path linked by the symbolic link.

        The returned path is built from "path" with its trailing key replaced by
        the key of the link target.

        Args:
            path (str): File path or URL.
            client_kwargs (dict): Client arguments.
            header (dict): Object header.

        Returns:
            str: Path.

        Raises:
            ObjectNotASymlinkError: The path has no key (Bucket root).
        """
        if client_kwargs is None:
            client_kwargs = self.get_client_kwargs(path)
        try:
            key = client_kwargs["key"]
        except KeyError:
            raise _ObjectNotASymlinkError(path=path)

        with _handle_oss_error():
            return path.rsplit(key, 1)[0] + (
                self._get_bucket(client_kwargs).get_symlink(symlink_key=key).target_key
            )

    def symlink(self, target, path=None, client_kwargs=None):
        """Create a symbolic link to target.

        Args:
            target (str): Target path or URL.
            path (str): File path or URL.
            client_kwargs (dict): Client arguments.

        Raises:
            ObjectNotImplementedError: The link and its target are not in the same
                bucket, or one of them is a bucket root.
        """
        if client_kwargs is None:
            client_kwargs = self.get_client_kwargs(path)
        target_client_kwargs = self.get_client_kwargs(target)

        if client_kwargs["bucket_name"] != target_client_kwargs["bucket_name"]:
            raise _ObjectNotImplementedError("Cross bucket symlinks are not supported")

        try:
            symlink_key = client_kwargs["key"]
            target_key = target_client_kwargs["key"]
        except KeyError:
            raise _ObjectNotImplementedError(
                "Symlinks to or from bucket root are not supported"
            )

        with _handle_oss_error():
            return self._get_bucket(client_kwargs).put_symlink(target_key, symlink_key)


class OSSRawIO(_ObjectRawIOBase):
    """Binary OSS Object I/O.

    Args:
        name (path-like object): URL or path to the file which will be opened.
        mode (str): The mode can be 'r', 'w', 'a' for reading (default), writing or
            appending.
        storage_parameters (dict): OSS2 Auth keyword arguments and endpoint.
            This is generally OSS credentials and configuration.
        unsecure (bool): If True, disables TLS/SSL to improve transfer performance.
            But makes connection unsecure.
    """

    _SYSTEM_CLASS = _OSSSystem

    @property  # type: ignore
    @_memoizedmethod
    def _bucket(self):
        """Bucket client.

        Returns:
            oss2.Bucket: Client.
        """
        return self._system._get_bucket(self._client_kwargs)

    @property  # type: ignore
    @_memoizedmethod
    def _key(self):
        """Object key.

        Returns:
            str: key.
        """
        return self._client_kwargs["key"]

    def _read_range(self, start, end=0):
        """Read a range of bytes in stream.

        Args:
            start (int): Start stream position.
            end (int): End stream position. 0 To not specify the end.

        Returns:
            bytes: Bytes read. Empty if "start" is at or after the end of the
                object.
        """
        if start >= self._size:
            # EOF. Do not detect using 416 (Out of range) error, 200 returned.
            return bytes()

        with _handle_oss_error():
            response = self._bucket.get_object(
                key=self._key,
                headers=dict(
                    # The requested end is clamped to the object size.
                    Range=self._http_range(start, min(end, self._size))
                ),
            )
        return response.read()

    def _readall(self):
        """Read and return all the bytes from the stream until EOF.

        Returns:
            bytes: Object content
        """
        with _handle_oss_error():
            return self._bucket.get_object(key=self._key).read()

    def _flush(self, buffer):
        """Flush the write buffers of the stream if applicable.

        The whole buffer content is written as the object in a single request.

        Args:
            buffer (memoryview): Buffer content.
        """
        with _handle_oss_error():
            self._bucket.put_object(key=self._key, data=buffer.tobytes())
class OSSBufferedIO(_ObjectBufferedIOBase):
    """Buffered binary OSS Object I/O.

    In "write" mode, each flushed buffer is uploaded as one part of an OSS
    multipart upload, which is completed when the object is closed.
    """

    __slots__ = ("_bucket", "_key", "_upload_id")

    _RAW_CLASS = OSSRawIO

    #: Minimal buffer_size in bytes (OSS multipart upload minimal part size)
    MINIMUM_BUFFER_SIZE = 102400

    def __init__(self, *args, **kwargs):
        """Init.

        Args:
            name (path-like object): URL or path to the file which will be opened.
            mode (str): The mode can be 'r', 'w' for reading (default) or writing
            buffer_size (int): The size of buffer.
            max_buffers (int): The maximum number of buffers to preload in read mode
                or awaiting flush in "write" mode. 0 for no limit.
            max_workers (int): The maximum number of threads that can be used to
                execute the given calls.
            storage_parameters (dict): OSS2 Auth keyword arguments and endpoint.
                This is generally OSS credentials and configuration.
            unsecure (bool): If True, disables TLS/SSL to improve transfer
                performance. But makes connection unsecure.
        """
        _ObjectBufferedIOBase.__init__(self, *args, **kwargs)

        # Reuse the bucket client and the key of the underlying raw stream.
        self._bucket = self._raw._bucket
        self._key = self._raw._key

        # Multipart upload ID, created by the first flush.
        self._upload_id = None

    def _flush(self):
        """Flush the write buffers of the stream.

        The multipart upload is initiated on the first call. The buffer is then
        submitted to the workers pool as a part numbered with the current seek
        value, and the pending result is recorded in the write futures.
        """
        if self._upload_id is None:
            with _handle_oss_error():
                self._upload_id = self._bucket.init_multipart_upload(
                    self._key
                ).upload_id

        response = self._workers.submit(
            self._bucket.upload_part,
            key=self._key,
            upload_id=self._upload_id,
            part_number=self._seek,
            data=self._get_buffer().tobytes(),
        )

        self._write_futures.append(dict(response=response, part_number=self._seek))

    def _close_writable(self):
        """Close the object in "write" mode.

        Waits for all the pending part uploads, then completes the multipart
        upload. If a part upload or the completion fails with an OSS error, the
        multipart upload is aborted before the error is raised.

        Raises:
            OSError subclasses: IO error.
        """
        with _handle_oss_error():
            try:
                # "result()" re-raises any error that happened in the worker while
                # uploading the part: it must be waited inside the guarded
                # section so a failed part also aborts the upload.
                parts = [
                    _PartInfo(
                        part_number=future["part_number"],
                        etag=future["response"].result().etag,
                    )
                    for future in self._write_futures
                ]
                self._bucket.complete_multipart_upload(
                    key=self._key, upload_id=self._upload_id, parts=parts
                )
            except _OssError:
                self._bucket.abort_multipart_upload(
                    key=self._key, upload_id=self._upload_id
                )
                raise