Add search support

This commit is contained in:
kokarare1212 2021-09-12 18:40:51 +09:00
parent 82a3ad2047
commit e7875145ad
No known key found for this signature in database
GPG Key ID: 9FB32C7C7D874F7A
2 changed files with 116 additions and 7 deletions

View File

@ -1,4 +1,7 @@
from __future__ import annotations from __future__ import annotations
import urllib.parse
from Cryptodome import Random from Cryptodome import Random
from Cryptodome.Hash import HMAC, SHA1 from Cryptodome.Hash import HMAC, SHA1
from Cryptodome.PublicKey import RSA from Cryptodome.PublicKey import RSA
@ -572,8 +575,8 @@ class MessageType(enum.Enum):
class Session(Closeable, MessageListener, SubListener): class Session(Closeable, MessageListener, SubListener):
cipher_pair: typing.Union[CipherPair, None] cipher_pair: typing.Union[CipherPair, None]
country_code: str = "EN"
connection: typing.Union[ConnectionHolder, None] connection: typing.Union[ConnectionHolder, None]
country_code: str
logger = logging.getLogger("Librespot:Session") logger = logging.getLogger("Librespot:Session")
scheduled_reconnect: typing.Union[sched.Event, None] = None scheduled_reconnect: typing.Union[sched.Event, None] = None
scheduler = sched.scheduler(time.time) scheduler = sched.scheduler(time.time)
@ -594,6 +597,7 @@ class Session(Closeable, MessageListener, SubListener):
__keys: DiffieHellman __keys: DiffieHellman
__mercury_client: MercuryClient __mercury_client: MercuryClient
__receiver: typing.Union[Receiver, None] __receiver: typing.Union[Receiver, None]
__search: typing.Union[SearchManager, None]
__server_key = b"\xac\xe0F\x0b\xff\xc20\xaf\xf4k\xfe\xc3\xbf\xbf\x86=" \ __server_key = b"\xac\xe0F\x0b\xff\xc20\xaf\xf4k\xfe\xc3\xbf\xbf\x86=" \
b"\xa1\x91\xc6\xcc3l\x93\xa1O\xb3\xb0\x16\x12\xac\xacj" \ b"\xa1\x91\xc6\xcc3l\x93\xa1O\xb3\xb0\x16\x12\xac\xacj" \
b"\xf1\x80\xe7\xf6\x14\xd9B\x9d\xbe.4fC\xe3b\xd22z\x1a" \ b"\xf1\x80\xe7\xf6\x14\xd9B\x9d\xbe.4fC\xe3b\xd22z\x1a" \
@ -658,6 +662,7 @@ class Session(Closeable, MessageListener, SubListener):
self.__content_feeder = PlayableContentFeeder(self) self.__content_feeder = PlayableContentFeeder(self)
self.__cache_manager = CacheManager(self) self.__cache_manager = CacheManager(self)
self.__dealer_client = DealerClient(self) self.__dealer_client = DealerClient(self)
self.__search = SearchManager(self)
self.__event_service = EventService(self) self.__event_service = EventService(self)
self.__auth_lock_bool = False self.__auth_lock_bool = False
self.__auth_lock.notify_all() self.__auth_lock.notify_all()
@ -880,6 +885,9 @@ class Session(Closeable, MessageListener, SubListener):
self.logger.debug("Parsed product info: {}".format( self.logger.debug("Parsed product info: {}".format(
self.__user_attributes)) self.__user_attributes))
def preferred_locale(self) -> str:
return self.__inner.preferred_locale
def reconnect(self) -> None: def reconnect(self) -> None:
""" """
Reconnect to the Spotify Server Reconnect to the Spotify Server
@ -904,6 +912,12 @@ class Session(Closeable, MessageListener, SubListener):
def reconnecting(self) -> bool: def reconnecting(self) -> bool:
return not self.__closing and not self.__closed and self.connection is None return not self.__closing and not self.__closed and self.connection is None
def search(self) -> SearchManager:
self.__wait_auth_lock()
if self.__search is None:
raise RuntimeError("Session isn't authenticated!")
return self.__search
def send(self, cmd: bytes, payload: bytes): def send(self, cmd: bytes, payload: bytes):
""" """
Send data to socket using send_unchecked Send data to socket using send_unchecked
@ -927,6 +941,9 @@ class Session(Closeable, MessageListener, SubListener):
raise RuntimeError("Session isn't authenticated!") raise RuntimeError("Session isn't authenticated!")
return self.__token_provider return self.__token_provider
def username(self):
return self.__ap_welcome.canonical_username
def __authenticate_partial(self, def __authenticate_partial(self,
credential: Authentication.LoginCredentials, credential: Authentication.LoginCredentials,
remove_lock: bool) -> None: remove_lock: bool) -> None:
@ -1534,10 +1551,10 @@ class Session(Closeable, MessageListener, SubListener):
elif cmd == Packet.Type.pong_ack: elif cmd == Packet.Type.pong_ack:
continue continue
elif cmd == Packet.Type.country_code: elif cmd == Packet.Type.country_code:
self.__session.country_code = packet.payload.decode() self.__session.__country_code = packet.payload.decode()
self.__session.logger.info( self.__session.logger.info(
"Received country_code: {}".format( "Received country_code: {}".format(
self.__session.country_code)) self.__session.__country_code))
elif cmd == Packet.Type.license_version: elif cmd == Packet.Type.license_version:
license_version = io.BytesIO(packet.payload) license_version = io.BytesIO(packet.payload)
license_id = struct.unpack(">h", license_id = struct.unpack(">h",
@ -1576,6 +1593,98 @@ class Session(Closeable, MessageListener, SubListener):
Keyexchange.ErrorCode.Name(login_failed.error_code)) Keyexchange.ErrorCode.Name(login_failed.error_code))
class SearchManager:
base_url = "hm://searchview/km/v4/search/"
__session: Session
def __init__(self, session: Session):
self.__session = session
def request(self, request: SearchRequest) -> typing.Any:
if request.get_username() == "":
request.set_username(self.__session.username())
if request.get_country() == "":
request.set_country(self.__session.country_code)
if request.get_locale() == "":
request.set_locale(self.__session.preferred_locale())
response = self.__session.mercury().send_sync(RawMercuryRequest.new_builder()
.set_method("GET").set_uri(request.build_url()).build())
if response.status_code != 200:
raise SearchManager.SearchException(response.status_code)
return json.loads(response.payload)
class SearchException(Exception):
def __init__(self, status_code: int):
super().__init__("Search failed with code {}.".format(status_code))
class SearchRequest:
query: typing.Final[str]
__catalogue = ""
__country = ""
__image_size = ""
__limit = 10
__locale = ""
__username = ""
def __init__(self, query: str):
self.query = query
if query == "":
raise TypeError
def build_url(self) -> str:
url = SearchManager.base_url + urllib.parse.quote(self.query)
url += "?entityVersion=2"
url += "&catalogue=" + urllib.parse.quote(self.__catalogue)
url += "&country=" + urllib.parse.quote(self.__country)
url += "&imageSize=" + urllib.parse.quote(self.__image_size)
url += "&limit=" + str(self.__limit)
url += "&locale=" + urllib.parse.quote(self.__locale)
url += "&username=" + urllib.parse.quote(self.__username)
return url
def get_catalogue(self) -> str:
return self.__catalogue
def get_country(self) -> str:
return self.__country
def get_image_size(self) -> str:
return self.__image_size
def get_limit(self) -> int:
return self.__limit
def get_locale(self) -> str:
return self.__locale
def get_username(self) -> str:
return self.__username
def set_catalogue(self, catalogue: str) -> SearchManager.SearchRequest:
self.__catalogue = catalogue
return self
def set_country(self, country: str) -> SearchManager.SearchRequest:
self.__country = country
return self
def set_image_size(self, image_size: str) -> SearchManager.SearchRequest:
self.__image_size = image_size
return self
def set_limit(self, limit: int) -> SearchManager.SearchRequest:
self.__limit = limit
return self
def set_locale(self, locale: str) -> SearchManager.SearchRequest:
self.__locale = locale
return self
def set_username(self, username: str) -> SearchManager.SearchRequest:
self.__username = username
return self
class TokenProvider: class TokenProvider:
logger = logging.getLogger("Librespot:TokenProvider") logger = logging.getLogger("Librespot:TokenProvider")
token_expire_threshold = 10 token_expire_threshold = 10

View File

@ -182,7 +182,7 @@ class MercuryClient(Closeable, PacketsReceiver):
def send_sync_json(self, request: JsonMercuryRequest) -> typing.Any: def send_sync_json(self, request: JsonMercuryRequest) -> typing.Any:
response = self.send_sync(request.request) response = self.send_sync(request.request)
if 200 <= response.status_code < 300: if 200 <= response.status_code < 300:
return json.loads(response.payload[0]) return json.loads(response.payload)
raise MercuryClient.MercuryException(response) raise MercuryClient.MercuryException(response)
def subscribe(self, uri: str, listener: SubListener) -> None: def subscribe(self, uri: str, listener: SubListener) -> None:
@ -266,14 +266,14 @@ class MercuryClient(Closeable, PacketsReceiver):
class Response: class Response:
uri: str uri: str
payload: typing.List[bytes] payload: bytes
status_code: int status_code: int
def __init__(self, header: Mercury.Header, def __init__(self, header: Mercury.Header,
payload: typing.List[bytes]): payload: list[bytes]):
self.uri = header.uri self.uri = header.uri
self.status_code = header.status_code self.status_code = header.status_code
self.payload = payload[1:] self.payload = b"".join(payload[1:])
class SyncCallback(Callback): class SyncCallback(Callback):
__reference = queue.Queue() __reference = queue.Queue()