# Source code for andesite.discord

"""Utilities for the discord.py library.

If you're using the discord.py library, you can use the functions provided by
this module to your advantage. Instead of having to manually send the voice
server updates, you can call `add_voice_server_update_handler` once and be done
with it. All events received by discord.py will then automatically be forwarded
to the Andesite client.
You can also use `connect_voice_channel` and `disconnect_voice_channel`
to easily connect to and disconnect from voice channels.
This works for both `discord.ext.commands.Bot` and normal `discord.Client`
instances without interfering.

Attributes:
    SOCKET_RESPONSE_HANDLERS_ATTR (str): Name of the attribute used to store the
        `SocketResponseHandler` instances in discord clients.
"""
import asyncio
import logging
from asyncio import Future
from typing import Any, Callable, Dict, Iterable, Optional, Set, TYPE_CHECKING, Tuple, Union, cast, overload

import andesite

if TYPE_CHECKING:
    try:
        from discord import Client, VoiceChannel, Guild
        # noinspection PyUnresolvedReferences
        from discord.ext.commands import Bot
        from discord.gateway import DiscordWebSocket
        from discord.state import ConnectionState
    except ImportError:
        pass

# Names exported by ``from andesite.discord import *``; this is the module's public API.
__all__ = ["get_discord_websocket",
           "update_voice_state", "connect_voice_channel", "disconnect_voice_channel",
           "AsyncMethodGroup", "get_async_method_group", "wrap_client_listener", "unwrap_client_listener",
           "SocketResponseHandler", "get_andesite_socket_response_handlers",
           "add_voice_server_update_handler", "remove_voice_server_update_handler",
           "compare_regions", "create_region_comparator"]

# Module-level logger, named after this module.
log = logging.getLogger(__name__)


def get_discord_websocket(client: Union["Client", "ConnectionState"], guild_id: int) -> "DiscordWebSocket":
    """Return discord.py's gateway websocket responsible for a guild.

    Args:
        client: Either a discord.py client or its ``ConnectionState``.
        guild_id: Id of the guild whose websocket should be returned.

    Returns:
        Whatever ``_get_websocket(guild_id)`` of the connection state returns.

    Raises:
        AttributeError: If neither ``client._connection._get_websocket`` nor
            ``client._get_websocket`` exists.
    """
    missing = object()

    # A client keeps its ConnectionState in ``_connection``; prefer that one.
    # noinspection PyProtectedMember
    connection = getattr(client, "_connection", missing)
    lookup = missing if connection is missing else getattr(connection, "_get_websocket", missing)

    if lookup is missing:
        # Presumably a ConnectionState was passed in directly.
        # noinspection PyProtectedMember,PyUnresolvedReferences
        lookup = client._get_websocket

    return lookup(guild_id)
async def update_voice_state(client: "Client", guild_id: int, channel_id: Optional[int]) -> None:
    """Send a voice state update through the guild's gateway websocket.

    Args:
        client: discord.py client.
        guild_id: Id of the guild the update targets.
        channel_id: Id of the voice channel to join, or `None` to leave the
            current voice channel.
    """
    await get_discord_websocket(client, guild_id).voice_state(guild_id, channel_id)
@overload
async def connect_voice_channel(client: "Client", channel: "VoiceChannel") -> None:
    ...


@overload
async def connect_voice_channel(client: "Client", guild: Union["Guild", int], channel_id: int) -> None:
    ...


async def connect_voice_channel(client: "Client", *args, **kwargs) -> None:
    """Join a voice channel.

    Two call forms are accepted: pass a single `VoiceChannel`, or pass a
    guild (or guild id) together with the id of the voice channel.
    """
    # Positional and keyword arguments are counted together, so callers can
    # freely mix both styles (e.g. ``connect_voice_channel(client, channel=vc)``).
    if len(args) + len(kwargs) == 1:
        target = _connect_voice_channel_channel
    else:
        target = _connect_voice_channel_guild_channel

    await target(client, *args, **kwargs)
async def _connect_voice_channel_channel(client: "Client", channel: "VoiceChannel") -> None:
    """Join ``channel`` in the guild it belongs to."""
    guild_id = channel.guild.id
    await update_voice_state(client, guild_id, channel.id)


async def _connect_voice_channel_guild_channel(client: "Client", guild: Union["Guild", int], channel_id: int) -> None:
    """Join the voice channel ``channel_id`` in ``guild`` (a guild object or its id)."""
    if isinstance(guild, int):
        guild_id = guild
    else:
        guild_id = guild.id

    await update_voice_state(client, guild_id, channel_id)
async def disconnect_voice_channel(client: "Client", guild: Union["Guild", int]) -> None:
    """Leave the current voice channel in ``guild``.

    Args:
        client: discord.py client.
        guild: Guild object or guild id.
    """
    if not isinstance(guild, int):
        guild = guild.id

    # Passing `None` as channel id requests a disconnect.
    await update_voice_state(client, guild, None)
class AsyncMethodGroup:
    """Group of async functions which act as a single method.

    When called, the group calls every function it contains with the same
    arguments and gathers their results. The class does not implement
    ``__get__``, so an instance stored on an object is never bound like a
    method and never receives a "self" argument. Every argument it is called
    with is forwarded unchanged to each function.

    Attributes:
        methods: The async callables that make up the group.
    """
    __slots__ = ("methods",)

    methods: Set[Callable]

    def __init__(self, methods: Iterable[Callable]) -> None:
        self.methods = set(methods)

    def __call__(self, *args, **kwargs) -> Future:
        """Call every function of the group with the given arguments.

        Returns:
            Future resolving to the list of results, ordered like the
            iteration order of `methods`.
        """
        # ``gather`` expects the awaitables as separate positional arguments,
        # so the generator must be unpacked rather than passed as one object.
        return asyncio.gather(*(func(*args, **kwargs) for func in self.methods))
def get_async_method_group(obj: Any, name: str) -> AsyncMethodGroup:
    """Return the `AsyncMethodGroup` stored on ``obj`` under ``name``.

    - If the attribute already is a group, it is returned as-is.
    - If it exists but is something else, a new group containing that value
      replaces it.
    - If it doesn't exist, an empty group is created.

    In the last two cases the new group is set on ``obj`` before it is returned.
    """
    sentinel = object()
    current = getattr(obj, name, sentinel)

    if isinstance(current, AsyncMethodGroup):
        return current

    group = AsyncMethodGroup(() if current is sentinel else (current,))
    setattr(obj, name, group)
    return group
def wrap_client_listener(client: "Client", func: Callable, *, name: Optional[str] = None) -> None:
    """Add a listener method to the discord client.

    This makes it possible to add a listener if the client isn't a
    `discord.ext.commands.Bot`. The listener is added to an `AsyncMethodGroup`
    stored as an attribute on the client. If an attribute with that name
    already exists, it is kept in the group and still called.

    Args:
        client: Client to add the listener to.
        func: Listener function to add.
        name: Custom attribute name to use. Defaults to the name of the
            function. Be sure to include the `on_` prefix if you set this.

    See Also:
        `unwrap_client_listener` to remove it again.
    """
    name = name or func.__name__
    group = get_async_method_group(client, name)
    group.methods.add(func)
def unwrap_client_listener(client: "Client", func: Callable, *, name: Optional[str] = None) -> None:
    """Remove a listener method from a discord client.

    Args:
        client: Client to remove the listener from.
        func: Listener function to remove.
        name: Custom attribute name to use. Defaults to the name of the
            function. Be sure to include the `on_` prefix if you set this.

    Raises:
        KeyError: If ``func`` is not part of the group stored under ``name``.
            If no group exists yet, an empty one is created on the client
            first, so this is raised as well.
    """
    name = name or func.__name__
    group = get_async_method_group(client, name)
    group.methods.remove(func)
class SocketResponseHandler:
    """Bridge between a discord.py client and an Andesite client.

    Listens to the discord client's socket responses and forwards every
    ``VOICE_SERVER_UPDATE`` to the Andesite client.

    Attributes:
        discord_client (discord.Client): discord.py client that is listened to.
        andesite_client (WebSocketInterface): Andesite client that receives the voice server updates.
    """
    __slots__ = ("discord_client", "andesite_client")

    discord_client: "Client"
    andesite_client: andesite.WebSocketInterface

    def __init__(self, discord_client: "Client", andesite_client: andesite.WebSocketInterface) -> None:
        self.discord_client = discord_client
        self.andesite_client = andesite_client

    def add_listener(self) -> None:
        """Register `on_socket_response` with the discord client.

        Clients with an ``add_listener`` method (a `Bot`) use that listener
        framework. Any other client gets the handler added through
        `wrap_client_listener`, which keeps existing listeners.
        """
        client = self.discord_client
        try:
            register = cast("Bot", client).add_listener
        except AttributeError:
            wrap_client_listener(client, self.on_socket_response)
            return

        log.info(f"Adding socket response listener to {client}")
        register(self.on_socket_response)

    def remove_listener(self) -> None:
        """Undo `add_listener`, using the same Bot / plain-client distinction."""
        client = self.discord_client
        try:
            unregister = cast("Bot", client).remove_listener
        except AttributeError:
            unwrap_client_listener(client, self.on_socket_response)
            return

        log.info(f"Removing socket response listener from {client}")
        unregister(self.on_socket_response)

    async def on_socket_response(self, data: Dict[str, Any]) -> None:
        """Forward ``VOICE_SERVER_UPDATE`` payloads to Andesite.

        Payloads without ``t`` or ``d`` keys, and events of any other type,
        are ignored. The update is only sent when the guild's gateway
        websocket reports a (truthy) session id.
        """
        try:
            event_type, payload = data["t"], data["d"]
        except KeyError:
            return

        if event_type != "VOICE_SERVER_UPDATE":
            return

        guild_id = int(payload["guild_id"])
        session_id = get_discord_websocket(self.discord_client, guild_id).session_id

        if not session_id:
            log.debug(f"not sending voice server update for guild {guild_id} because session id missing.")
            return

        log.info(f"sending voice server update for guild {guild_id}")
        await self.andesite_client.voice_server_update(guild_id, session_id, payload)
# Attribute name under which the handler registry is stored on discord clients.
SOCKET_RESPONSE_HANDLERS_ATTR: str = "__andesite_socket_response_handlers__"


def get_andesite_socket_response_handlers(obj: Any) -> Dict[andesite.WebSocketInterface, SocketResponseHandler]:
    """Return the handler registry stored on a discord client.

    The registry maps each Andesite client to the `SocketResponseHandler`
    registered for it. If ``obj`` doesn't have a registry yet, an empty dict
    is attached to it under `SOCKET_RESPONSE_HANDLERS_ATTR` and returned.
    """
    try:
        return getattr(obj, SOCKET_RESPONSE_HANDLERS_ATTR)
    except AttributeError:
        pass

    registry = {}
    setattr(obj, SOCKET_RESPONSE_HANDLERS_ATTR, registry)
    return registry
def add_voice_server_update_handler(discord_client: "Client", andesite_client: andesite.WebSocketInterface) -> None:
    """Forward the discord client's voice server updates to an Andesite client.

    A `SocketResponseHandler` is created, registered as a socket response
    listener on ``discord_client`` and remembered, so that it triggers
    `WebSocketInterface.voice_server_update`. Calling this again for the same
    pair of clients does nothing.

    Args:
        discord_client: Discord client to add the listener to.
        andesite_client: Andesite web socket client to use to send the voice server update.
    """
    handlers = get_andesite_socket_response_handlers(discord_client)
    if andesite_client not in handlers:
        handler = SocketResponseHandler(discord_client, andesite_client)
        handler.add_listener()
        handlers[andesite_client] = handler
def remove_voice_server_update_handler(discord_client: "Client", andesite_client: andesite.WebSocketInterface) -> None:
    """Undo `add_voice_server_update_handler` for the given pair of clients.

    Does nothing if no handler was registered for ``andesite_client``.

    Args:
        discord_client: Discord client to remove the listener from.
        andesite_client: Andesite web socket client the handler was registered for.
    """
    handlers = get_andesite_socket_response_handlers(discord_client)
    handler = handlers.pop(andesite_client, None)
    if handler is not None:
        handler.remove_listener()
def _split_region(region: str) -> Tuple[bool, str, Optional[str]]:
    """Parse a region string into its parts.

    The string is lower-cased and underscores are treated like hyphens, so
    ``"VIP-US-East"`` and ``"vip_us_east"`` parse the same way.

    Args:
        region: Region string to parse (e.g. ``"us-west"``, ``"vip-us-east"``, ``"europe"``).

    Returns:
        3-tuple:
            - Whether the region is VIP or not
            - Country name
            - Part of country, or `None` if the region doesn't specify one
    """
    parts = region.lower().replace("_", "-").split("-")

    # "vip" is only a prefix when something follows it; a bare "vip" would
    # otherwise leave no country to parse.
    vip = len(parts) > 1 and parts[0] == "vip"
    if vip:
        del parts[0]

    # str.split always returns at least one element, so parts[0] exists.
    country = parts[0]
    country_part = parts[1] if len(parts) > 1 else None
    return vip, country, country_part


def compare_regions(a_region: str, b_region: str) -> int:
    """Compare two region names.

    The order of the provided regions is irrelevant and the VIP prefix is ignored.

    Args:
        a_region: First region name
        b_region: Second region name

    Returns:
        Similarity score (higher means closer):

        - 0 if the regions are in different countries
        - 2 if both regions are in the same country
        - 3 if they are additionally in the same part of that country
          (ex: us-west). This also applies when neither region specifies a
          country part.
    """
    _, a_country, a_country_part = _split_region(a_region)
    _, b_country, b_country_part = _split_region(b_region)

    if a_country != b_country:
        return 0

    score = 2
    # The "same part" point only makes sense within the same country.
    if a_country_part == b_country_part:
        score += 1
    return score
def create_region_comparator(discord_client: "Client", *,
                             region_comp: Optional[Callable[[str, str], int]] = None) -> andesite.RegionGuildComparator:
    """Create a region comparator which compares the guild region with the node region.

    You can use the created comparator for the `WebSocketPool`.

    Args:
        discord_client: Client to use to determine the guild region.
        region_comp: Specify the actual function which compares the guild region
            and the Andesite node region. Defaults to `compare_regions`.

    Returns:
        Region comparator which uses ``region_comp`` to compare the guild
        region with the node region. The comparator returns 0 when the node
        region is `None` or ``"unknown"``, when the guild can't be found, or
        when the guild has no region.
    """
    if region_comp is None:
        region_comp = compare_regions

    def comparator(guild_id: int, node_region: Optional[str]) -> int:
        if node_region is None or node_region == "unknown":
            return 0

        guild: Optional["Guild"] = discord_client.get_guild(guild_id)
        if guild is None:
            return 0

        # NOTE: discord.py documents ``Guild.region`` as a ``VoiceRegion`` enum
        # (a plain str only for unrecognised values), while ``region_comp``
        # expects a str. Newer discord.py versions drop the attribute entirely.
        guild_region = getattr(guild, "region", None)
        if guild_region is None:
            return 0

        # str() yields the enum's value (e.g. "us-west") and leaves str unchanged.
        return region_comp(node_region, str(guild_region))

    return comparator