mirror of
https://github.com/kokarare1212/librespot-python.git
synced 2024-10-06 02:56:55 +02:00
cec130aa39
This commit fixes the style issues introduced in b779a6a
according to the output
from yapf.
Details: https://deepsource.io/gh/kokarare1212/librespot-python/transform/ef955754-0053-413b-8235-a6a87e8254d0/
257 lines
8.7 KiB
Python
257 lines
8.7 KiB
Python
from __future__ import annotations
|
|
from librespot.common import Utils
|
|
from librespot.core import Session, PacketsReceiver
|
|
from librespot.crypto import Packet
|
|
from librespot.mercury import JsonMercuryRequest, RawMercuryRequest, SubListener
|
|
from librespot.standard import BytesInputStream, BytesOutputStream, Closeable
|
|
from librespot.proto import Mercury, Pubsub
|
|
import json
|
|
import logging
|
|
import queue
|
|
import threading
|
|
import typing
|
|
|
|
|
|
class MercuryClient(PacketsReceiver.PacketsReceiver, Closeable):
    """Send Mercury requests through a session and route the replies.

    Outgoing requests are serialized by :meth:`send` and written to the
    session. Incoming packets are handed to :meth:`dispatch`, which
    reassembles multi-packet responses and forwards them either to the
    callback registered for the sequence number or to every subscription
    listener whose URI prefix matches.
    """

    _LOGGER: logging.Logger = logging.getLogger(__name__)
    # Seconds a synchronous request (and close()) waits for a response.
    _MERCURY_REQUEST_TIMEOUT: int = 3

    # Per-instance state; the mutable containers and locks are created in
    # __init__ so that separate clients never share them.
    _seqHolder: int = 1
    _seqHolderLock: threading.Condition
    _callbacks: typing.Dict[int, MercuryClient.Callback]
    _removeCallbackLock: threading.Condition
    _subscriptions: typing.List[MercuryClient.InternalSubListener]
    _subscriptionsLock: threading.Condition
    _partials: typing.Dict[int, typing.List[bytes]]
    _session: Session = None

    def __init__(self, session: Session):
        """Create a client that sends its requests through *session*."""
        self._session = session
        self._seqHolder = 1
        self._seqHolderLock = threading.Condition()
        self._callbacks = {}
        self._removeCallbackLock = threading.Condition()
        self._subscriptions = []
        self._subscriptionsLock = threading.Condition()
        self._partials = {}

    def subscribe(self, uri: str, listener: SubListener) -> None:
        """Subscribe to *uri* and register *listener* for its events.

        One listener entry is registered per subscription returned in the
        response payload; when the payload is empty a single entry for
        *uri* itself is registered.

        Raises:
            RuntimeError: if the response status code is not 200.
            IOError: if no response arrives in time (see :meth:`send_sync`).
        """
        response = self.send_sync(RawMercuryRequest.sub(uri))
        if response.status_code != 200:
            raise RuntimeError(response)

        if response.payload:
            for payload in response.payload:
                sub = Pubsub.Subscription()
                sub.ParseFromString(payload)
                self._subscriptions.append(
                    MercuryClient.InternalSubListener(sub.uri, listener, True))
        else:
            self._subscriptions.append(
                MercuryClient.InternalSubListener(uri, listener, True))

        self._LOGGER.debug("Subscribed successfully to {}!".format(uri))

    def unsubscribe(self, uri) -> None:
        """Unsubscribe from *uri* and drop the first matching listener entry.

        Raises:
            RuntimeError: if the response status code is not 200.
            IOError: if no response arrives in time (see :meth:`send_sync`).
        """
        response = self.send_sync(RawMercuryRequest.unsub(uri))
        if response.status_code != 200:
            raise RuntimeError(response)

        for subscription in self._subscriptions:
            if subscription.matches(uri):
                # Safe to mutate here: the loop is left immediately.
                self._subscriptions.remove(subscription)
                break
        self._LOGGER.debug("Unsubscribed successfully from {}!".format(uri))

    def send_sync(self, request: RawMercuryRequest) -> MercuryClient.Response:
        """Send *request* and block until its response has been dispatched.

        Returns:
            The response delivered for this request.

        Raises:
            IOError: if no response arrives within
                ``_MERCURY_REQUEST_TIMEOUT`` seconds.
        """
        callback = MercuryClient.SyncCallback()
        seq = self.send(request, callback)

        try:
            resp = callback.wait_response()
        except queue.Empty as e:
            # Nobody is waiting any more; do not leave a stale callback behind.
            self._callbacks.pop(seq, None)
            raise IOError(
                "Request timeout out, {} passed, yet no response. seq: {}".
                format(self._MERCURY_REQUEST_TIMEOUT, seq)) from e

        if resp is None:
            raise IOError(
                "Request timeout out, {} passed, yet no response. seq: {}".
                format(self._MERCURY_REQUEST_TIMEOUT, seq))
        return resp

    def send_sync_json(self, request: JsonMercuryRequest) -> typing.Any:
        """Send the wrapped request and decode the first payload part as JSON.

        Raises:
            MercuryClient.MercuryException: if the status code is not 2xx.
            IOError: if no response arrives in time (see :meth:`send_sync`).
        """
        resp = self.send_sync(request.request)
        if 200 <= resp.status_code < 300:
            return json.loads(resp.payload[0])
        raise MercuryClient.MercuryException(resp)

    def send(self, request: RawMercuryRequest, callback) -> int:
        """Serialize *request*, send it and register *callback* for the reply.

        Returns:
            The sequence number assigned to the request.
        """
        buffer = BytesOutputStream()

        seq: int
        with self._seqHolderLock:
            seq = self._seqHolder
            self._seqHolder += 1

        self._LOGGER.debug(
            "Send Mercury request, seq: {}, uri: {}, method: {}".format(
                seq, request.header.uri, request.header.method))

        # Sequence number: length prefix (4) followed by the value.
        buffer.write_short(4)
        buffer.write_int(seq)

        # Flags byte, then the part count (header + payload parts).
        buffer.write_byte(1)
        buffer.write_short(1 + len(request.payload))

        header_bytes = request.header.SerializeToString()
        buffer.write_short(len(header_bytes))
        buffer.write(header_bytes)

        for part in request.payload:
            buffer.write_short(len(part))
            buffer.write(part)

        cmd = Packet.Type.for_method(request.header.method)
        self._session.send(cmd, buffer.buffer)

        # NOTE(review): the callback is registered after the packet is sent;
        # a reply dispatched before this line would be skipped — confirm
        # whether registering first is preferable.
        self._callbacks[seq] = callback
        return seq

    def dispatch(self, packet: Packet) -> None:
        """Handle one incoming Mercury packet.

        Parts are accumulated per sequence number until the final packet of
        a response is seen; the completed response is then delivered to
        matching subscription listeners (events) or to the callback
        registered for that sequence number (request/sub/unsub replies).

        Raises:
            RuntimeError: if the sequence-number length is not 2, 4 or 8.
        """
        payload = BytesInputStream(packet.payload)
        seq_length = payload.read_short()
        if seq_length == 2:
            seq = payload.read_short()
        elif seq_length == 4:
            seq = payload.read_int()
        elif seq_length == 8:
            seq = payload.read_long()
        else:
            raise RuntimeError("Unknown seq length: {}".format(seq_length))

        flags = payload.read_byte()
        parts = payload.read_short()

        # NOTE(review): `flags` is compared with the int 0 here and with the
        # bytes value b"\x01" below; only one form can match what
        # read_byte() returns — confirm against BytesInputStream.
        partial = self._partials.get(seq)
        if partial is None or flags == 0:
            partial = []
            self._partials[seq] = partial

        self._LOGGER.debug(
            "Handling packet, cmd: 0x{}, seq: {}, flags: {}, parts: {}".format(
                Utils.bytes_to_hex(packet.cmd), seq, flags, parts))

        # `partial` is the list already stored in self._partials, so the
        # appended parts are retained for the next packet of this seq.
        for _ in range(parts):
            size = payload.read_short()
            partial.append(payload.read(size))

        if flags != b"\x01":
            # Not the final packet of this response; wait for more parts.
            return

        self._partials.pop(seq)

        # The first part is the serialized header, the rest is the payload.
        header = Mercury.Header()
        header.ParseFromString(partial[0])

        resp = MercuryClient.Response(header, partial)

        if packet.is_cmd(Packet.Type.mercury_event):
            dispatched = False
            with self._subscriptionsLock:
                for sub in self._subscriptions:
                    if sub.matches(header.uri):
                        sub.dispatch(resp)
                        dispatched = True

            if not dispatched:
                self._LOGGER.debug(
                    "Couldn't dispatch Mercury event seq: {}, uri: {}, code: {}, payload: {}"
                    .format(seq, header.uri, header.status_code, resp.payload))
        elif packet.is_cmd(Packet.Type.mercury_req) or \
                packet.is_cmd(Packet.Type.mercury_sub) or \
                packet.is_cmd(Packet.Type.mercury_unsub):
            # pop() with a default: an unknown seq must reach the warning
            # below instead of raising KeyError.
            callback = self._callbacks.pop(seq, None)
            if callback is not None:
                callback.response(resp)
            else:
                self._LOGGER.warning(
                    "Skipped Mercury response, seq: {}, uri: {}, code: {}".
                    format(seq, resp.uri, resp.status_code))

            # Wake close(), which may be waiting for pending callbacks.
            with self._removeCallbackLock:
                self._removeCallbackLock.notify_all()
        else:
            self._LOGGER.warning(
                "Couldn't handle packet, seq: {}, uri: {}, code: {}".format(
                    seq, header.uri, header.status_code))

    def interested_in(self, uri: str, listener: SubListener) -> None:
        """Register *listener* for events on *uri* without subscribing."""
        self._subscriptions.append(
            MercuryClient.InternalSubListener(uri, listener, False))

    def not_interested_in(self, listener: SubListener) -> None:
        """Remove every registration that forwards to *listener*.

        The subscription list stores ``InternalSubListener`` wrappers, so
        entries are matched on the wrapped listener (an entry passed in
        directly is also accepted). Unknown listeners are ignored.
        """
        self._subscriptions[:] = [
            sub for sub in self._subscriptions
            if sub is not listener and sub.listener is not listener
        ]

    def close(self) -> None:
        """Release all subscriptions and pending callbacks.

        Subscribed entries are unsubscribed, plain listeners are removed,
        and pending callbacks get up to ``_MERCURY_REQUEST_TIMEOUT`` seconds
        to complete before being discarded.
        """
        # Iterate over a copy: unsubscribe()/not_interested_in() mutate
        # self._subscriptions.
        for listener in list(self._subscriptions):
            try:
                if listener.isSub:
                    self.unsubscribe(listener.uri)
                else:
                    self.not_interested_in(listener.listener)
            except (IOError, RuntimeError) as ex:
                # Best effort: one failing entry must not abort the cleanup.
                self._LOGGER.debug("Failed unsubscribing from {}: {}".format(
                    listener.uri, ex))

        if self._callbacks:
            with self._removeCallbackLock:
                self._removeCallbackLock.wait(self._MERCURY_REQUEST_TIMEOUT)

        self._callbacks.clear()

    class Callback:
        """Base type for objects that receive a dispatched response."""

        def response(self, response) -> None:
            """Receive *response*; the base implementation does nothing."""
            pass

    class SyncCallback(Callback):
        """Callback that lets a caller block until the response arrives."""

        _reference: queue.Queue

        def __init__(self):
            # One queue per callback, so that concurrent requests cannot
            # receive each other's responses.
            self._reference = queue.Queue()

        def response(self, response) -> None:
            """Hand *response* over to the thread in wait_response()."""
            self._reference.put(response)
            self._reference.task_done()

        def wait_response(self) -> typing.Any:
            """Block until a response is available and return it.

            Raises:
                queue.Empty: if nothing arrives within
                    ``MercuryClient._MERCURY_REQUEST_TIMEOUT`` seconds.
            """
            return self._reference.get(
                timeout=MercuryClient._MERCURY_REQUEST_TIMEOUT)

    # class PubSubException(MercuryClient.MercuryException):
    #     pass

    class InternalSubListener:
        """Pairs a URI prefix with the listener to notify for it."""

        uri: str
        listener: SubListener
        # True when the entry was created by subscribe(), False when it was
        # created by interested_in().
        isSub: bool

        def __init__(self, uri: str, listener: SubListener, is_sub: bool):
            self.uri = uri
            self.listener = listener
            self.isSub = is_sub

        def matches(self, uri: str) -> bool:
            """Return True when *uri* starts with this entry's URI."""
            return uri.startswith(self.uri)

        def dispatch(self, resp: MercuryClient.Response) -> None:
            """Forward *resp* to the wrapped listener's ``event`` method."""
            self.listener.event(resp)

    class MercuryException(Exception):
        """Raised when a Mercury response carries a non-success status."""

        # Status code of the offending response.
        code: int

        def __init__(self, response):
            super().__init__("status: {}".format(response.status_code))
            self.code = response.status_code

    class Response:
        """A fully reassembled Mercury response."""

        uri: str
        payload: typing.List[bytes]
        status_code: int

        def __init__(self, header: Mercury.Header,
                     payload: typing.List[bytes]):
            self.uri = header.uri
            self.status_code = header.status_code
            # The first part is the serialized header; keep only the body.
            self.payload = payload[1:]
|