Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support MSC3916 by adding a federation /download endpoint #17172

Merged
merged 19 commits into from
Jun 7, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
208 changes: 202 additions & 6 deletions synapse/media/media_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,12 @@
#
#
import contextlib
import json
import logging
import os
import shutil
from contextlib import closing
from io import BytesIO
from types import TracebackType
from typing import (
IO,
Expand All @@ -31,14 +34,19 @@
BinaryIO,
Callable,
Generator,
List,
Optional,
Sequence,
Tuple,
Type,
Union,
)
from uuid import uuid4

import attr
from zope.interface import implementer

from twisted.internet import defer, interfaces
from twisted.internet.defer import Deferred
from twisted.internet.interfaces import IConsumer
from twisted.protocols.basic import FileSender
Expand All @@ -49,6 +57,8 @@
from synapse.util import Clock
from synapse.util.file_consumer import BackgroundFileConsumer

from ..storage.databases.main.media_repository import LocalMedia
from ..types import JsonDict
from ._base import FileInfo, Responder
from .filepath import MediaFilePaths

Expand All @@ -58,6 +68,8 @@

logger = logging.getLogger(__name__)

CRLF = b"\r\n"


class MediaStorage:
"""Responsible for storing/fetching files from local sources.
Expand Down Expand Up @@ -202,15 +214,23 @@ async def finish() -> None:
)
raise exc

async def fetch_media(self, file_info: FileInfo) -> Optional[Responder]:
async def fetch_media(
self,
file_info: FileInfo,
media_info: Optional[LocalMedia] = None,
federation: bool = False,
) -> Optional[Responder]:
"""Attempts to fetch media described by file_info from the local cache
and configured storage providers.

Args:
file_info
file_info: Metadata about the media file
media_info: Metadata about the media item
federation: Whether this file is being fetched for a federation request

Returns:
Returns a Responder if the file was found, otherwise None.
If the file was found returns a Responder (a Multipart Responder if the requested
file is for the federation /download endpoint), otherwise None.
"""
paths = [self._file_info_to_path(file_info)]

Expand All @@ -230,12 +250,19 @@ async def fetch_media(self, file_info: FileInfo) -> Optional[Responder]:
local_path = os.path.join(self.local_media_directory, path)
if os.path.exists(local_path):
logger.debug("responding with local file %s", local_path)
return FileResponder(open(local_path, "rb"))
if federation:
assert media_info is not None
boundary = uuid4().hex.encode("ascii")
return MultipartResponder(
open(local_path, "rb"), media_info, boundary
)
else:
return FileResponder(open(local_path, "rb"))
logger.debug("local file %s did not exist", local_path)

for provider in self.storage_providers:
for path in paths:
res: Any = await provider.fetch(path, file_info)
res: Any = await provider.fetch(path, file_info, media_info, federation)
if res:
logger.debug("Streaming %s from %s", path, provider)
return res
Expand Down Expand Up @@ -349,7 +376,7 @@ class FileResponder(Responder):
"""Wraps an open file that can be sent to a request.

Args:
open_file: A file like object to be streamed ot the client,
open_file: A file like object to be streamed to the client,
is closed when finished streaming.
"""

Expand All @@ -370,6 +397,38 @@ def __exit__(
self.open_file.close()


class MultipartResponder(Responder):
"""Wraps an open file, formats the response according to MSC3916 and sends it to a
federation request.

Args:
open_file: A file like object to be streamed to the client,
is closed when finished streaming.
media_info: metadata about the media item
boundary: bytes to use for the multipart response boundary
"""

def __init__(self, open_file: IO, media_info: LocalMedia, boundary: bytes) -> None:
self.open_file = open_file
self.media_info = media_info
self.boundary = boundary

def write_to_consumer(self, consumer: IConsumer) -> Deferred:
return make_deferred_yieldable(
MultipartFileSender().beginFileTransfer(
self.open_file, consumer, self.media_info.media_type, {}, self.boundary
)
H-Shay marked this conversation as resolved.
Show resolved Hide resolved
)

def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
self.open_file.close()


class SpamMediaException(NotFoundError):
"""The media was blocked by a spam checker, so we simply 404 the request (in
the same way as if it was quarantined).
Expand Down Expand Up @@ -403,3 +462,140 @@ async def write_chunks_to(self, callback: Callable[[bytes], object]) -> None:

# We yield to the reactor by sleeping for 0 seconds.
await self.clock.sleep(0)


@implementer(interfaces.IProducer)
class MultipartFileSender:
"""
A producer that sends the contents of a file to a federation request in the format
outlined in MSC3916 - a multipart/format-data response where the first field is a
JSON object and the second is the requested file.

This is a slight re-writing of twisted.protocols.basic.FileSender to achieve the format
outlined above.
"""

CHUNK_SIZE = 2**14

lastSent = ""
deferred: Optional[defer.Deferred] = None

def beginFileTransfer(
self,
file: IO,
consumer: IConsumer,
file_content_type: str,
json_object: JsonDict,
boundary: bytes,
) -> Deferred:
"""
Begin transferring a file

Args:
file: The file object to read data from
consumer: The synapse request to write the data to
file_content_type: The content-type of the file
json_object: The JSON object to write to the first field of the response
boundary: bytes to be used as the multipart/form-data boundary

Returns: A deferred whose callback will be invoked when the file has
been completely written to the consumer. The last byte written to the
consumer is passed to the callback.
anoadragon453 marked this conversation as resolved.
Show resolved Hide resolved
"""
self.file: Optional[IO] = file
self.consumer = consumer
self.json_field = json_object
self.json_field_written = False
self.content_type_written = False
self.file_content_type = file_content_type
self.boundary = boundary
self.deferred: Deferred = defer.Deferred()
self.consumer.registerProducer(self, False)

deferred = self.deferred
return deferred
H-Shay marked this conversation as resolved.
Show resolved Hide resolved

def resumeProducing(self) -> None:
anoadragon453 marked this conversation as resolved.
Show resolved Hide resolved
if not self.json_field_written:
self.consumer.write(CRLF + b"--" + self.boundary + b"" + CRLF)
anoadragon453 marked this conversation as resolved.
Show resolved Hide resolved
content_type = Header(b"Content-Type", b"application/json")
self.consumer.write(bytes(content_type) + CRLF)
json_field = json.dumps(self.json_field)
json_bytes = json_field.encode("utf-8")
self.consumer.write(json_bytes)
self.consumer.write(CRLF + b"--" + self.boundary + b"" + CRLF)
self.json_field_written = True
chunk: Any = ""
if self.file:
if not self.content_type_written:
type = self.file_content_type.encode("utf-8")
content_type = Header(b"Content-Type", type)
self.consumer.write(bytes(content_type) + CRLF)
self.content_type_written = True
chunk = self.file.read(self.CHUNK_SIZE)
if not chunk:
self.consumer.write(CRLF + b"--" + self.boundary + b"--" + CRLF)
self.file = None
self.consumer.unregisterProducer()
if self.deferred:
self.deferred.callback(self.lastSent)
self.deferred = None
return

self.consumer.write(chunk)
self.lastSent = chunk[-1:]

def pauseProducing(self) -> None:
pass

def stopProducing(self) -> None:
if self.deferred:
self.deferred.errback(Exception("Consumer asked us to stop producing"))
self.deferred = None


class Header:
"""
`Header` This class is a tiny wrapper that produces
request headers. We can't use standard python header
class because it encodes unicode fields using =? bla bla ?=
encoding, which is correct, but no one in HTTP world expects
that, everyone wants utf-8 raw bytes. (stolen from treq.multipart)

"""

def __init__(
self,
name: bytes,
value: Any,
params: Optional[List[Tuple[Any, Any]]] = None,
):
self.name = name
self.value = value
self.params = params or []

def add_param(self, name: Any, value: Any) -> None:
self.params.append((name, value))

def __bytes__(self) -> bytes:
with closing(BytesIO()) as h:
h.write(self.name + b": " + escape(self.value).encode("us-ascii"))
if self.params:
for name, val in self.params:
h.write(b"; ")
h.write(escape(name).encode("us-ascii"))
h.write(b"=")
h.write(b'"' + escape(val).encode("utf-8") + b'"')
h.seek(0)
return h.read()


def escape(value: Union[str, bytes]) -> str:
"""
This function prevents header values from corrupting the request,
a newline in the file name parameter makes form-data request unreadable
for a majority of parsers. (stolen from treq.multipart)
"""
if isinstance(value, bytes):
value = value.decode("utf-8")
return value.replace("\r", "").replace("\n", "").replace('"', '\\"')