1
Fork 0
mirror of https://github.com/RGBCube/GitHubWrapper synced 2025-05-18 23:15:09 +00:00

refactor w/ black

This commit is contained in:
NextChai 2022-04-30 02:27:16 -04:00
parent cc3cde89c8
commit b474492637
8 changed files with 187 additions and 165 deletions

View file

@ -1,4 +1,4 @@
#== __init__.py ==#
# == __init__.py ==#
__title__ = 'Github-Api-Wrapper'
__authors__ = 'VarMonke', 'sudosnok'

View file

@ -1,4 +1,4 @@
#== cache.py ==#
# == cache.py ==#
from __future__ import annotations
@ -6,9 +6,7 @@ from collections import deque
from collections.abc import MutableMapping
from typing import Any, Deque, Tuple, TypeVar
__all__: Tuple[str, ...] = (
'ObjectCache',
)
__all__: Tuple[str, ...] = ('ObjectCache',)
K = TypeVar('K')
@ -50,6 +48,7 @@ class _BaseCache(MutableMapping[K, V]):
class ObjectCache(_BaseCache[K, V]):
"""This adjusts the typehints to reflect Github objects."""
def __getitem__(self, __k: K) -> V:
index = self._lru_keys.index(__k)
target = self._lru_keys[index]

View file

@ -1,4 +1,4 @@
#== exceptions.py ==#
# == exceptions.py ==#
import datetime
from typing import TYPE_CHECKING
@ -24,108 +24,141 @@ __all__ = (
'IssueNotFound',
'OrganizationNotFound',
'RepositoryAlreadyExists',
)
)
class APIError(Exception):
"""Base level exceptions raised by errors related to any API request or call."""
pass
class HTTPException(Exception):
"""Base level exceptions raised by errors related to HTTP requests."""
pass
class ClientException(Exception):
"""Base level exceptions raised by errors related to the client."""
pass
class ResourceNotFound(Exception):
"""Base level exceptions raised when a resource is not found."""
pass
class ResourceAlreadyExists(Exception):
"""Base level exceptions raised when a resource already exists."""
pass
class Ratelimited(APIError):
"""Raised when the ratelimit from Github is reached or exceeded."""
def __init__(self, reset_time: datetime.datetime):
formatted = reset_time.strftime(r"%H:%M:%S %A, %d %b")
msg = "We're being ratelimited, wait until {}.\nAuthentication raises the ratelimit.".format(formatted)
super().__init__(msg)
class WillExceedRatelimit(APIError):
"""Raised when the library predicts the call will exceed the ratelimit, will abort the call by default."""
def __init__(self, response: ClientRequest, count: int):
msg = 'Performing this action will exceed the ratelimit, aborting.\n{} remaining available calls, calls to make: {}.'
msg = msg.format(response.headers['X-RateLimit-Remaining'], count)
super().__init__(msg)
class NoAuthProvided(ClientException):
"""Raised when no authentication is provided."""
def __init__(self):
msg = 'This action required autentication. Pass username and token kwargs to your client instance.'
super().__init__(msg)
class InvalidToken(ClientException):
"""Raised when the token provided is invalid."""
def __init__(self):
msg = 'The token provided is invalid.'
super().__init__(msg)
class InvalidAuthCombination(ClientException):
"""Raised when the username and token are both provided."""
def __init__(self):
msg = 'The username and token cannot be used together.'
super().__init__(msg)
class LoginFailure(ClientException):
"""Raised when the login attempt fails."""
def __init__(self):
msg = 'The login attempt failed. Provide valid credentials.'
super().__init__(msg)
class NotStarted(ClientException):
"""Raised when the client is not started."""
def __init__(self):
msg = 'The client is not started. Run Github.GHClient() to start.'
super().__init__(msg)
class AlreadyStarted(ClientException):
"""Raised when the client is already started."""
def __init__(self):
msg = 'The client is already started.'
super().__init__(msg)
class MissingPermissions(APIError):
def __init__(self):
msg = 'You do not have permissions to perform this action.'
super().__init__(msg)
class UserNotFound(ResourceNotFound):
def __init__(self):
msg = 'The requested user was not found.'
super().__init__(msg)
class RepositoryNotFound(ResourceNotFound):
def __init__(self):
msg = 'The requested repository is either private or does not exist.'
super().__init__(msg)
class IssueNotFound(ResourceNotFound):
def __init__(self):
msg = 'The requested issue was not found.'
super().__init__(msg)
class OrganizationNotFound(ResourceNotFound):
def __init__(self):
msg = 'The requested organization was not found.'
super().__init__(msg)
class GistNotFound(ResourceNotFound):
def __init__(self):
msg = 'The requested gist was not found.'
super().__init__(msg)
class RepositoryAlreadyExists(ResourceAlreadyExists):
def __init__(self):
msg = 'The requested repository already exists.'

View file

@ -1,4 +1,4 @@
#== http.py ==#
# == http.py ==#
from __future__ import annotations
@ -29,19 +29,14 @@ Rates = NamedTuple('Rates', 'remaining', 'used', 'total', 'reset_when', 'last_re
# aiohttp request tracking / checking bits
async def on_req_start(
session: aiohttp.ClientSession,
ctx: SimpleNamespace,
params: aiohttp.TraceRequestStartParams
session: aiohttp.ClientSession, ctx: SimpleNamespace, params: aiohttp.TraceRequestStartParams
) -> None:
"""Before-request hook to make sure we don't overrun the ratelimit."""
#print(repr(session), repr(ctx), repr(params))
# print(repr(session), repr(ctx), repr(params))
pass
async def on_req_end(
session: aiohttp.ClientSession,
ctx: SimpleNamespace,
params: aiohttp.TraceRequestEndParams
) -> None:
async def on_req_end(session: aiohttp.ClientSession, ctx: SimpleNamespace, params: aiohttp.TraceRequestEndParams) -> None:
"""After-request hook to adjust remaining requests on this time frame."""
headers = params.response.headers
@ -53,36 +48,38 @@ async def on_req_end(
session._rates = Rates(remaining, used, total, reset_when, last_req)
trace_config = aiohttp.TraceConfig()
trace_config.on_request_start.append(on_req_start)
trace_config.on_request_end.append(on_req_end)
APIType: TypeAlias = Union[User, Gist, Repository]
async def make_session(*, headers: Dict[str, str], authorization: Union[aiohttp.BasicAuth, None]) -> aiohttp.ClientSession:
"""This makes the ClientSession, attaching the trace config and ensuring a UA header is present."""
if not headers.get('User-Agent'):
headers['User-Agent'] = f'Github-API-Wrapper (https://github.com/VarMonke/Github-Api-Wrapper) @ {__version__} Python {platform.python_version()} aiohttp {aiohttp.__version__}'
headers[
'User-Agent'
] = f'Github-API-Wrapper (https://github.com/VarMonke/Github-Api-Wrapper) @ {__version__} Python {platform.python_version()} aiohttp {aiohttp.__version__}'
session = aiohttp.ClientSession(
auth=authorization,
headers=headers,
trace_configs=[trace_config]
)
session._rates = Rates('', '' , '', '', '')
session = aiohttp.ClientSession(auth=authorization, headers=headers, trace_configs=[trace_config])
session._rates = Rates('', '', '', '', '')
return session
# pagination
class Paginator:
"""This class handles pagination for objects like Repos and Orgs."""
def __init__(self, session: aiohttp.ClientSession, response: aiohttp.ClientResponse, target_type: str):
self.session = session
self.response = response
self.should_paginate = bool(self.response.headers.get('Link', False))
types: Dict[str, Type[APIType]] = { # note: the type checker doesnt see subclasses like that
'user': User,
'gist' : Gist,
'repo' : Repository
'gist': Gist,
'repo': Repository,
}
self.target_type: Type[APIType] = types[target_type]
self.pages = {}
@ -105,7 +102,7 @@ class Paginator:
return await self.early_return()
out: List[APIType] = []
for page in range(1, self.max_page+1):
for page in range(1, self.max_page + 1):
result = await self.session.get(self.bare_link + str(page))
out.extend([self.target_type(item, self) for item in await result.json()]) # type: ignore
@ -121,13 +118,17 @@ class Paginator:
raise WillExceedRatelimit(response, self.max_page)
self.bare_link = groups[0][0][:-1]
# GithubUserData = GithubRepoData = GithubIssueData = GithubOrgData = GithubGistData = Dict[str, Union [str, int]]
# Commentnig this out for now, consider using TypeDict's instead in the future <3
class http:
def __init__(self, headers: Dict[str, Union[str, int]], auth: Union[aiohttp.BasicAuth, None]) -> None:
if not headers.get('User-Agent'):
headers['User-Agent'] = f'Github-API-Wrapper (https://github.com/VarMonke/Github-Api-Wrapper) @ {__version__} Python/{platform.python_version()} aiohttp/{aiohttp.__version__}'
headers[
'User-Agent'
] = f'Github-API-Wrapper (https://github.com/VarMonke/Github-Api-Wrapper) @ {__version__} Python/{platform.python_version()} aiohttp/{aiohttp.__version__}'
self._rates = Rates('', '', '', '', '')
self.headers = headers
@ -149,6 +150,7 @@ class http:
def update_headers(self, *, flush: bool = False, new_headers: Dict[str, Union[str, int]]):
if flush:
from multidict import CIMultiDict
self.session._default_headers = CIMultiDict(**new_headers) # type: ignore
else:
self.session._default_headers = {**self.session.headers, **new_headers} # type: ignore
@ -158,14 +160,10 @@ class http:
headers = self.session.headers
config = self.session.trace_configs
await self.session.close()
self.session = aiohttp.ClientSession(
headers=headers,
auth=auth,
trace_configs=config
)
self.session = aiohttp.ClientSession(headers=headers, auth=auth, trace_configs=config)
def data(self):
#return session headers and auth
# return session headers and auth
headers = {**self.session.headers}
return {'headers': headers, 'auth': self.auth}
@ -175,21 +173,21 @@ class http:
await self.session.get(BASE_URL)
return (datetime.utcnow() - start).total_seconds()
async def get_self(self) -> Dict[str, Union [str, int]]:
async def get_self(self) -> Dict[str, Union[str, int]]:
"""Returns the authenticated User's data"""
result = await self.session.get(SELF_URL)
if 200 <= result.status <= 299:
return await result.json()
raise InvalidToken
async def get_user(self, username: str) -> Dict[str, Union [str, int]]:
async def get_user(self, username: str) -> Dict[str, Union[str, int]]:
"""Returns a user's public data in JSON format."""
result = await self.session.get(USERS_URL.format(username))
if 200 <= result.status <= 299:
return await result.json()
raise UserNotFound
async def get_user_repos(self, _user: User) -> List[Dict[str, Union [str, int]]]:
async def get_user_repos(self, _user: User) -> List[Dict[str, Union[str, int]]]:
result = await self.session.get(USER_REPOS_URL.format(_user.login))
if 200 <= result.status <= 299:
return await result.json()
@ -197,7 +195,7 @@ class http:
print('This shouldn\'t be reachable')
return []
async def get_user_gists(self, _user: User) -> List[Dict[str, Union [str, int]]]:
async def get_user_gists(self, _user: User) -> List[Dict[str, Union[str, int]]]:
result = await self.session.get(USER_GISTS_URL.format(_user.login))
if 200 <= result.status <= 299:
return await result.json()
@ -205,7 +203,7 @@ class http:
print('This shouldn\'t be reachable')
return []
async def get_user_orgs(self, _user: User) -> List[Dict[str, Union [str, int]]]:
async def get_user_orgs(self, _user: User) -> List[Dict[str, Union[str, int]]]:
result = await self.session.get(USER_ORGS_URL.format(_user.login))
if 200 <= result.status <= 299:
return await result.json()
@ -213,14 +211,14 @@ class http:
print('This shouldn\'t be reachable')
return []
async def get_repo(self, owner: str, repo_name: str) -> Dict[str, Union [str, int]]:
async def get_repo(self, owner: str, repo_name: str) -> Dict[str, Union[str, int]]:
"""Returns a Repo's raw JSON from the given owner and repo name."""
result = await self.session.get(REPO_URL.format(owner, repo_name))
if 200 <= result.status <= 299:
return await result.json()
raise RepositoryNotFound
async def get_repo_issue(self, owner: str, repo_name: str, issue_number: int) -> Dict[str, Union [str, int]]:
async def get_repo_issue(self, owner: str, repo_name: str, issue_number: int) -> Dict[str, Union[str, int]]:
"""Returns a single issue's JSON from the given owner and repo name."""
result = await self.session.get(REPO_ISSUE_URL.format(owner, repo_name, issue_number))
if 200 <= result.status <= 299:
@ -245,14 +243,14 @@ class http:
raise MissingPermissions
raise GistNotFound
async def get_org(self, org_name: str) -> Dict[str, Union [str, int]]:
async def get_org(self, org_name: str) -> Dict[str, Union[str, int]]:
"""Returns an org's public data in JSON format."""
result = await self.session.get(ORG_URL.format(org_name))
if 200 <= result.status <= 299:
return await result.json()
raise OrganizationNotFound
async def get_gist(self, gist_id: int) -> Dict[str, Union [str, int]]:
async def get_gist(self, gist_id: int) -> Dict[str, Union[str, int]]:
"""Returns a gist's raw JSON from the given gist id."""
result = await self.session.get(GIST_URL.format(gist_id))
if 200 <= result.status <= 299:
@ -260,29 +258,26 @@ class http:
raise GistNotFound
async def create_gist(
self,
*,
files: List['File'] = [],
description: str = 'Default description',
public: bool = False
) -> Dict[str, Union [str, int]]:
self, *, files: List['File'] = [], description: str = 'Default description', public: bool = False
) -> Dict[str, Union[str, int]]:
data = {}
data['description'] = description
data['public'] = public
data['files'] = {}
for file in files:
data['files'][file.filename] = {
'filename' : file.filename, # helps editing the file
'content': file.read()
}
data['files'][file.filename] = {'filename': file.filename, 'content': file.read()} # helps editing the file
data = json.dumps(data)
_headers = dict(self.session.headers)
result = await self.session.post(CREATE_GIST_URL, data=data, headers=_headers|{'Accept': 'application/vnd.github.v3+json'})
result = await self.session.post(
CREATE_GIST_URL, data=data, headers=_headers | {'Accept': 'application/vnd.github.v3+json'}
)
if 201 == result.status:
return await result.json()
raise InvalidToken
async def create_repo(self, name: str, description: str, public: bool, gitignore: Optional[str], license: Optional[str]) -> Dict[str, Union [str, int]]:
async def create_repo(
self, name: str, description: str, public: bool, gitignore: Optional[str], license: Optional[str]
) -> Dict[str, Union[str, int]]:
"""Creates a repo for you with given data"""
data = {
'name': name,

View file

@ -74,9 +74,7 @@ class GHClient:
return f'<{self.__class__.__name__} has_auth={bool(self._auth)}>'
def __del__(self):
asyncio.create_task(
self.http.session.close(), name='cleanup-session-github-api-wrapper'
)
asyncio.create_task(self.http.session.close(), name='cleanup-session-github-api-wrapper')
@overload
def check_limits(self, as_dict: Literal[True] = True) -> Dict[str, Union[str, int]]:
@ -86,9 +84,7 @@ class GHClient:
def check_limits(self, as_dict: Literal[False] = False) -> List[str]:
...
def check_limits(
self, as_dict: bool = False
) -> Union[Dict[str, Union[str, int]], List[str]]:
def check_limits(self, as_dict: bool = False) -> Union[Dict[str, Union[str, int]], List[str]]:
if not self.has_started:
raise exceptions.NotStarted
if not as_dict:
@ -132,13 +128,9 @@ class GHClient:
]:
def wrapper(
func: Callable[Concatenate[Self, P], Awaitable[T]]
) -> Callable[
Concatenate[Self, P], Awaitable[Optional[Union[T, User, Repository]]]
]:
) -> Callable[Concatenate[Self, P], Awaitable[Optional[Union[T, User, Repository]]]]:
@functools.wraps(func)
async def wrapped(
self: Self, *args: P.args, **kwargs: P.kwargs
) -> Optional[Union[T, User, Repository]]:
async def wrapped(self: Self, *args: P.args, **kwargs: P.kwargs) -> Optional[Union[T, User, Repository]]:
if type == 'user':
if obj := self._user_cache.get(kwargs.get('user')):
return obj
@ -176,9 +168,7 @@ class GHClient:
async def get_issue(self, *, owner: str, repo: str, issue: int) -> Issue:
"""Fetch a Github Issue from it's name."""
return Issue(
await self.http.get_repo_issue(owner, repo, issue), self.http
)
return Issue(await self.http.get_repo_issue(owner, repo, issue), self.http)
async def create_repo(
self,
@ -202,14 +192,10 @@ class GHClient:
"""Fetch a Github gist from it's id."""
return Gist(await self.http.get_gist(gist), self.http)
async def create_gist(
self, *, files: List[File], description: str, public: bool
) -> Gist:
async def create_gist(self, *, files: List[File], description: str, public: bool) -> Gist:
"""Creates a Gist with the given files, requires authorisation."""
return Gist(
await self.http.create_gist(
files=files, description=description, public=public
),
await self.http.create_gist(files=files, description=description, public=public),
self.http,
)

View file

@ -1,4 +1,4 @@
#== objects.py ==#
# == objects.py ==#
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Optional, Tuple, Union, Dict
@ -23,23 +23,22 @@ __all__ = (
'Organization',
)
def dt_formatter(time_str: str) -> Optional[datetime]:
if time_str is not None:
return datetime.strptime(time_str, r"%Y-%m-%dT%H:%M:%SZ")
return None
def repr_dt(_datetime: datetime) -> str:
return _datetime.strftime(r'%d-%m-%Y, %H:%M:%S')
class APIObject:
__slots__: Tuple[str, ...] = (
'_response',
'_http'
)
__slots__: Tuple[str, ...] = ('_response', '_http')
def __init__(self, response: Dict[str, Any] , _http: http) -> None:
def __init__(self, response: Dict[str, Any], _http: http) -> None:
self._http = _http
self._response = response
@ -47,13 +46,15 @@ class APIObject:
return f'<{self.__class__.__name__}>'
#=== User stuff ===#
# === User stuff ===#
class _BaseUser(APIObject):
__slots__ = (
'login',
'id',
)
def __init__(self, response: Dict[str, Any], _http: http) -> None:
super().__init__(response, _http)
self._http = _http
@ -88,10 +89,11 @@ class User(_BaseUser):
'following',
'created_at',
)
def __init__(self, response: Dict[str, Any], _http: http) -> None:
super().__init__(response, _http)
tmp = self.__slots__ + _BaseUser.__slots__
keys = {key: value for key,value in self._response.items() if key in tmp}
keys = {key: value for key, value in self._response.items() if key in tmp}
for key, value in keys.items():
if '_at' in key and value is not None:
setattr(self, key, dt_formatter(value))
@ -126,7 +128,8 @@ class PartialUser(_BaseUser):
return User(response, self._http)
#=== Repository stuff ===#
# === Repository stuff ===#
class Repository(APIObject):
if TYPE_CHECKING:
@ -138,8 +141,7 @@ class Repository(APIObject):
'id',
'name',
'owner',
'size'
'created_at',
'size' 'created_at',
'url',
'html_url',
'archived',
@ -152,10 +154,11 @@ class Repository(APIObject):
'watchers_count',
'license',
)
def __init__(self, response: Dict[str, Any], _http: http) -> None:
super().__init__(response, _http)
tmp = self.__slots__ + APIObject.__slots__
keys = {key: value for key,value in self._response.items() if key in tmp}
keys = {key: value for key, value in self._response.items() if key in tmp}
for key, value in keys.items():
if key == 'owner':
setattr(self, key, PartialUser(value, self._http))
@ -198,6 +201,7 @@ class Repository(APIObject):
def forks(self) -> int:
return self._response.get('forks')
class Issue(APIObject):
__slots__ = (
'id',
@ -212,7 +216,7 @@ class Issue(APIObject):
def __init__(self, response: Dict[str, Any], _http: http) -> None:
super().__init__(response, _http)
tmp = self.__slots__ + APIObject.__slots__
keys = {key: value for key,value in self._response.items() if key in tmp}
keys = {key: value for key, value in self._response.items() if key in tmp}
for key, value in keys.items():
if key == 'user':
setattr(self, key, PartialUser(value, self._http))
@ -241,7 +245,9 @@ class Issue(APIObject):
def html_url(self) -> str:
return self._response.get('html_url')
#=== Gist stuff ===#
# === Gist stuff ===#
class File:
def __init__(self, fp: Union[str, io.StringIO], filename: str = 'DefaultFilename.txt') -> None:
@ -264,6 +270,7 @@ class File:
raise TypeError(f'Expected str, io.StringIO, or io.BytesIO, got {type(self.fp)}')
class Gist(APIObject):
__slots__ = (
'id',
@ -275,10 +282,11 @@ class Gist(APIObject):
'created_at',
'truncated',
)
def __init__(self, response: Dict[str, Any], _http: http) -> None:
super().__init__(response, _http)
tmp = self.__slots__ + APIObject.__slots__
keys = {key: value for key,value in self._response.items() if key in tmp}
keys = {key: value for key, value in self._response.items() if key in tmp}
for key, value in keys.items():
if key == 'owner':
setattr(self, key, PartialUser(value, self._http))
@ -309,7 +317,8 @@ class Gist(APIObject):
return self._response
#=== Organization stuff ===#
# === Organization stuff ===#
class Organization(APIObject):
__slots__ = (
@ -327,7 +336,7 @@ class Organization(APIObject):
def __init__(self, response: Dict[str, Any], _http: http) -> None:
super().__init__(response, _http)
tmp = self.__slots__ + APIObject.__slots__
keys = {key: value for key,value in self._response.items() if key in tmp}
keys = {key: value for key, value in self._response.items() if key in tmp}
for key, value in keys.items():
if key == 'login':
setattr(self, key, value)

View file

@ -1,9 +1,9 @@
#== urls.py ==#
# == urls.py ==#
BASE_URL = 'https://api.github.com'
#== user urls ==#
# == user urls ==#
USERS_URL = BASE_URL + '/users/{0}'
USER_HTML_URL = 'https://github.com/users/{0}'
@ -21,8 +21,8 @@ USER_FOLLOWERS_URL = USERS_URL + '/followers'
USER_FOLLOWING_URL = USERS_URL + '/following'
#== repo urls ==#
CREATE_REPO_URL = BASE_URL + '/user/repos' #_auth repo create
# == repo urls ==#
CREATE_REPO_URL = BASE_URL + '/user/repos' # _auth repo create
REPOS_URL = BASE_URL + '/repos/{0}' # repos of a user
@ -30,10 +30,10 @@ REPO_URL = BASE_URL + '/repos/{0}/{1}' # a specific repo
REPO_ISSUE_URL = REPO_URL + '/issues/{2}' # a specific issue
#== gist urls ==#
# == gist urls ==#
GIST_URL = BASE_URL + '/gists/{0}' # specific gist
CREATE_GIST_URL = BASE_URL + '/gists' # create a gist
#== org urls ==#
# == org urls ==#
ORG_URL = BASE_URL + '/orgs/{0}'