1
Fork 0
mirror of https://github.com/RGBCube/GitHubWrapper synced 2025-05-25 18:25:09 +00:00

Update main.py to add typehints

This commit is contained in:
NextChai 2022-04-30 02:26:00 -04:00
parent d482cf1bbe
commit 983a7cb094

View file

@ -1,26 +1,45 @@
#== main.py ==# # == main.py ==#
from __future__ import annotations from __future__ import annotations
from datetime import datetime
__all__ = ( __all__ = ("GHClient",)
'GHClient',
)
import asyncio import asyncio
import functools import functools
from typing import Any, Union, List, Dict
import aiohttp import aiohttp
from typing import (
TYPE_CHECKING,
Awaitable,
Callable,
Literal,
Any,
Coroutine,
Dict,
Generator,
Optional,
Union,
List,
overload,
TypeVar,
)
from typing_extensions import Self, ParamSpec, Concatenate
from . import exceptions from . import exceptions
from .cache import ObjectCache from .cache import ObjectCache
from .http import http from .http import http
from .objects import Gist, Issue, Organization, Repository, User, File from .objects import Gist, Issue, Organization, Repository, User, File
T = TypeVar("T")
P = ParamSpec("P")
class GHClient: class GHClient:
_auth = None if TYPE_CHECKING:
has_started = False http: http
http: http
has_started: bool = False
def __init__( def __init__(
self, self,
*, *,
@ -28,7 +47,7 @@ class GHClient:
token: Union[str, None] = None, token: Union[str, None] = None,
user_cache_size: int = 30, user_cache_size: int = 30,
repo_cache_size: int = 15, repo_cache_size: int = 15,
custom_headers: dict[str, Union[str, int]] = {} custom_headers: dict[str, Union[str, int]] = {},
): ):
"""The main client, used to start most use-cases.""" """The main client, used to start most use-cases."""
self._headers = custom_headers self._headers = custom_headers
@ -40,35 +59,57 @@ class GHClient:
self.token = token self.token = token
self._auth = aiohttp.BasicAuth(username, token) self._auth = aiohttp.BasicAuth(username, token)
def __await__(self) -> 'GHClient': # Cache manegent
self._cache(type='user')(self.get_self) # type: ignore
self._cache(type='user')(self.get_user) # type: ignore
self._cache(type='repo')(self.get_repo) # type: ignore
def __call__(self, *args: Any, **kwargs: Any) -> Coroutine[Any, Any, Self]:
return self.start(*args, **kwargs)
def __await__(self) -> Generator[Any, Any, Self]:
return self.start().__await__() return self.start().__await__()
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<{self.__class__.__name__} has_auth={bool(self._auth)}>' return f"<{self.__class__.__name__} has_auth={bool(self._auth)}>"
def __del__(self): def __del__(self):
asyncio.create_task(self.http.session.close()) asyncio.create_task(
self.http.session.close(), name="cleanup-session-github-api-wrapper"
)
def check_limits(self, as_dict: bool = False) -> dict[str, str | int] | list[str]: @overload
def check_limits(self, as_dict: Literal[True] = True) -> Dict[str, Union[str, int]]:
...
@overload
def check_limits(self, as_dict: Literal[False] = False) -> List[str]:
...
def check_limits(
self, as_dict: bool = False
) -> Union[Dict[str, Union[str, int]], List[str]]:
if not self.has_started: if not self.has_started:
raise exceptions.NotStarted raise exceptions.NotStarted
if not as_dict: if not as_dict:
output = [] output: List[str] = []
for key, value in self.http.session._rates._asdict().items(): for key, value in self.http.session._rates._asdict().items(): # type: ignore
output.append(f'{key} : {value}') output.append(f"{key} : {value}")
return output return output
return self.http.session._rates
return self.http.session._rates # type: ignore
async def update_auth(self, username: str, token: str) -> None: async def update_auth(self, username: str, token: str) -> None:
"""Allows you to input auth information after instantiating the client.""" """Allows you to input auth information after instantiating the client."""
#check if username and token is valid # check if username and token is valid
await self.http.update_auth(username=username, token=token) await self.http.update_auth(username=username, token=token)
try: try:
await self.http.get_self() await self.http.get_self()
except exceptions.InvalidToken as exc: except exceptions.InvalidToken as exc:
raise exceptions.InvalidToken from exc raise exceptions.InvalidToken from exc
async def start(self) -> 'GHClient': async def start(self) -> Self:
"""Main entry point to the wrapper, this creates the ClientSession.""" """Main entry point to the wrapper, this creates the ClientSession."""
if self.has_started: if self.has_started:
raise exceptions.AlreadyStarted raise exceptions.AlreadyStarted
@ -83,78 +124,103 @@ class GHClient:
self.has_started = True self.has_started = True
return self return self
def _cache(*args, **kwargs): def _cache(
target_type = kwargs.get('type') self: Self, *, type: str
def wrapper(func): ) -> Callable[
[Callable[Concatenate[Self, P], Awaitable[T]]],
Callable[Concatenate[Self, P], Awaitable[Optional[Union[T, User, Repository]]]],
]:
def wrapper(
func: Callable[Concatenate[Self, P], Awaitable[T]]
) -> Callable[
Concatenate[Self, P], Awaitable[Optional[Union[T, User, Repository]]]
]:
@functools.wraps(func) @functools.wraps(func)
async def wrapped(self, *args, **kwargs): async def wrapped(
if target_type == 'User': self: Self, *args: P.args, **kwargs: P.kwargs
if (obj := self._user_cache.get(kwargs.get('user'))): ) -> Optional[Union[T, User, Repository]]:
if type == "user":
if obj := self._user_cache.get(kwargs.get("user")):
return obj return obj
else:
res = await func(self, *args, **kwargs) user: User = await func(self, *args, **kwargs) # type: ignore
self._user_cache[kwargs.get('user')] = res self._user_cache[kwargs.get("user")] = user
return res return user
if target_type == 'Repo': if type == "repo":
if (obj := self._repo_cache.get(kwargs.get('repo'))): if obj := self._repo_cache.get(kwargs.get("repo")):
return obj return obj
else:
res = await func(self, *args, **kwargs) repo: Repository = await func(self, *args, **kwargs) # type: ignore
self._repo_cache[kwargs.get('repo')] = res self._repo_cache[kwargs.get("repo")] = repo
return res return repo
return wrapped return wrapped
return wrapper return wrapper
#@_cache(type='User') # @_cache(type='User')
async def get_self(self) -> User: async def get_self(self) -> User:
"""Returns the authenticated User object.""" """Returns the authenticated User object."""
if self._auth: if self._auth:
return User(await self.http.get_self(), self.http.session) return User(await self.http.get_self(), self.http)
else: else:
raise exceptions.NoAuthProvided raise exceptions.NoAuthProvided
@_cache(type='User')
async def get_user(self, *, user: str) -> User: async def get_user(self, *, user: str) -> User:
"""Fetch a Github user from their username.""" """Fetch a Github user from their username."""
return User(await self.http.get_user(user), self.http.session) return User(await self.http.get_user(user), self.http)
@_cache(type='Repo')
async def get_repo(self, *, owner: str, repo: str) -> Repository: async def get_repo(self, *, owner: str, repo: str) -> Repository:
"""Fetch a Github repository from it's name.""" """Fetch a Github repository from it's name."""
return Repository(await self.http.get_repo(owner, repo), self.http.session) return Repository(await self.http.get_repo(owner, repo), self.http)
async def get_issue(self, *, owner: str, repo: str, issue: int) -> Issue: async def get_issue(self, *, owner: str, repo: str, issue: int) -> Issue:
"""Fetch a Github Issue from it's name.""" """Fetch a Github Issue from it's name."""
return Issue(await self.http.get_repo_issue(owner, repo, issue), self.http.session) return Issue(
await self.http.get_repo_issue(owner, repo, issue), self.http
)
async def create_repo(self, name: str, description: str = 'Repository created using Github-Api-Wrapper.', public: bool = False,gitignore: str = None, license: str = None) -> Repository: async def create_repo(
return Repository(await self.http.create_repo(name,description,public,gitignore,license), self.http.session) self,
name: str,
description: str = "Repository created using Github-Api-Wrapper.",
public: bool = False,
gitignore: Optional[str] = None,
license: Optional[str] = None,
) -> Repository:
return Repository(
await self.http.create_repo(name, description, public, gitignore, license),
self.http,
)
async def delete_repo(self, repo: str= None, owner: str = None) -> None: async def delete_repo(self, repo: str, owner: str) -> Optional[str]:
"""Delete a Github repository, requires authorisation.""" """Delete a Github repository, requires authorisation."""
owner = owner or self.username owner = owner or self.username
return await self.http.delete_repo(owner, repo) return await self.http.delete_repo(owner, repo)
async def get_gist(self, gist: int) -> Gist: async def get_gist(self, gist: int) -> Gist:
"""Fetch a Github gist from it's id.""" """Fetch a Github gist from it's id."""
return Gist(await self.http.get_gist(gist), self.http.session) return Gist(await self.http.get_gist(gist), self.http)
async def create_gist(self, *, files: List[File], description: str, public: bool) -> Gist: async def create_gist(
self, *, files: List[File], description: str, public: bool
) -> Gist:
"""Creates a Gist with the given files, requires authorisation.""" """Creates a Gist with the given files, requires authorisation."""
return Gist(await self.http.create_gist(files=files, description=description, public=public), self.http.session) return Gist(
await self.http.create_gist(
files=files, description=description, public=public
),
self.http,
)
async def delete_gist(self, gist: int) -> None: async def delete_gist(self, gist: int) -> Optional[str]:
"""Delete a Github gist, requires authorisation.""" """Delete a Github gist, requires authorisation."""
return await self.http.delete_gist(gist) return await self.http.delete_gist(gist)
async def get_org(self, org: str) -> Organization: async def get_org(self, org: str) -> Organization:
"""Fetch a Github organization from it's name.""" """Fetch a Github organization from it's name."""
return Organization(await self.http.get_org(org), self.http.session) return Organization(await self.http.get_org(org), self.http)
async def latency(self) -> float: async def latency(self) -> float:
"""Returns the latency of the client.""" """Returns the latency of the client."""
return await self.http.latency() return await self.http.latency()