1
Fork 0
mirror of https://github.com/RGBCube/GitHubWrapper synced 2025-05-25 18:25:09 +00:00

Added a class to wrap around the session to allow raw returns

This commit is contained in:
sudosnok 2022-04-06 20:33:15 +01:00
parent 23bc1f5b74
commit b297e81bd0
2 changed files with 99 additions and 64 deletions

View file

@ -13,11 +13,8 @@ from .objects import *
from .urls import * from .urls import *
__all__ = ( __all__ = (
'make_session',
'Paginator', 'Paginator',
'get_user', 'http',
'get_repo_from_name',
'get_repo_issue',
) )
@ -112,53 +109,92 @@ class Paginator:
raise WillExceedRatelimit(response, self.max_page) raise WillExceedRatelimit(response, self.max_page)
self.bare_link = groups[0][0][:-1] self.bare_link = groups[0][0][:-1]
# user-related functions / utils GithubUserData = GithubRepoData = GithubIssueData = GithubOrgData = dict[str, str | int]
async def get_self(session: aiohttp.ClientSession) -> User:
result = await session.get(SELF_URL) class http:
if result.status == 200: def __init__(self, headers: dict[str, str | int], auth: aiohttp.BasicAuth | None):
return User(await result.json(), session) if not headers.get('User-Agent'):
headers['User-Agent'] = 'Github-API-Wrapper'
self._rates = Rates('', '', '', '', '')
self.headers = headers
self.auth = auth
def __await__(self):
return self.start().__await__()
async def start(self):
self.session = aiohttp.ClientSession(
headers=self.headers,
auth=self.auth,
trace_configs=[trace_config],
)
return self
def update_headers(self, *, flush: bool = False, new_headers: dict[str, str | int]):
if flush:
from multidict import CIMultiDict
self.session.headers = CIMultiDict(**new_headers)
else:
self.session.headers = {**self.session.headers, **new_headers}
async def update_auth(self, *, username: str, token: str):
auth = aiohttp.BasicAuth(username, token)
headers = self.session.headers
config = self.session.trace_configs
await self.session.close()
self.session = aiohttp.ClientSession(
headers=headers,
auth=auth,
trace_configs=config
)
async def get_self(self) -> GithubUserData:
"""Returns the authenticated User's data"""
result = await self.session.get(SELF_URL)
if 200 <= result.status <= 299:
return await result.json()
raise InvalidToken raise InvalidToken
async def get_user(session: aiohttp.ClientSession, username: str) -> User: async def get_user(self, username: str) -> GithubUserData:
"""Returns a user's public data in JSON format.""" """Returns a user's public data in JSON format."""
result = await session.get(USERS_URL.format(username)) result = await self.session.get(USERS_URL.format(username))
if result.status == 200: if 200 <= result.status <= 299:
return User(await result.json(), session) return await result.json()
raise UserNotFound raise UserNotFound
async def get_repo(self, owner: str, repo_name: str) -> GithubRepoData:
# repo-related functions / utils """Returns a Repo's raw JSON from the given owner and repo name."""
async def get_repo_from_name(session: aiohttp.ClientSession, owner: str, repo: str) -> Repository: result = await self.session.get(REPO_URL.format(owner, repo_name))
"""Returns a Repo object from the given owner and repo name.""" if 200 <= result.status <= 299:
result = await session.get(REPO_URL.format(owner, repo)) return await result.json()
if result.status == 200:
return Repository(await result.json(), session)
raise RepositoryNotFound raise RepositoryNotFound
async def get_repo_issue(session: aiohttp.ClientSession, owner: str, repo: str, issue: int) -> Issue: async def get_repo_issue(self, owner: str, repo_name: str, issue_number: int) -> GithubIssueData:
"""Returns a single issue from the given owner and repo name.""" """Returns a single issue's JSON from the given owner and repo name."""
result = await session.get(REPO_ISSUE_URL.format(owner, repo, issue)) result = await self.session.get(REPO_ISSUE_URL.format(owner, repo_name, issue_number))
if result.status == 200: if 200 <= result.status <= 299:
return Issue(await result.json(), session) return await result.json()
raise IssueNotFound raise IssueNotFound
async def create_repo(session: aiohttp.ClientSession, name: str, description: str, private: bool, gitignore_template: str, **kwargs) -> Repository: async def create_repo(self, **kwargs: dict[str, str | bool]) -> GithubRepoData:
"""Creates a new repo with the given name.""" """Creates a repo for you with given data"""
_data = {"name" : name, "description" : description, "private" : private, "gitignore_template" : gitignore_template} data = {
result = await session.post(MAKE_REPO_URL, data= json.dumps(_data)) 'name': kwargs.get('name'),
if result.status == 201: 'description': kwargs.get('description'),
return Repository(await result.json(), session) 'private': kwargs.get('private', True),
'gitignore_template': kwargs.get('gitignore'),
'license': kwargs.get('license'),
}
result = await self.session.post(CREATE_REPO_URL, data=json.dumps(data))
if 200 <= result.status <= 299:
return await result.json()
if result.status == 401: if result.status == 401:
raise NoAuthProvided raise NoAuthProvided
raise RepositoryAlreadyExists raise RepositoryAlreadyExists
async def get_org(self, org_name: str) -> GithubOrgData:
# org-related functions / utils
async def get_org(session: aiohttp.ClientSession, org: str) -> Organization:
"""Returns an org's public data in JSON format.""" """Returns an org's public data in JSON format."""
result = await session.get(ORG_URL.format(org)) result = await self.session.get(ORG_URL.format(org_name))
if result.status == 200: if 200 <= result.status <= 299:
return Organization(await result.json(), session) return await result.json()
raise OrganizationNotFound raise OrganizationNotFound

View file

@ -9,7 +9,7 @@ import asyncio
import functools import functools
from getpass import getpass from getpass import getpass
from . import http from .http import http
from . import exceptions from . import exceptions
from .objects import User, PartialUser, Repository, Organization, Issue from .objects import User, PartialUser, Repository, Organization, Issue
from .cache import UserCache, RepoCache from .cache import UserCache, RepoCache
@ -17,6 +17,7 @@ from .cache import UserCache, RepoCache
class GHClient: class GHClient:
_auth = None _auth = None
has_started = False has_started = False
http: http
def __init__( def __init__(
self, self,
*, *,
@ -43,36 +44,34 @@ class GHClient:
return f'<Github Client; has_auth={bool(self._auth)}>' return f'<Github Client; has_auth={bool(self._auth)}>'
def __del__(self): def __del__(self):
asyncio.create_task(self.session.close()) asyncio.create_task(self.http.session.close())
def check_limits(self, as_dict: bool = False) -> dict[str, str | int] | list[str]: def check_limits(self, as_dict: bool = False) -> dict[str, str | int] | list[str]:
if not self.has_started: if not self.has_started:
raise exceptions.NotStarted raise exceptions.NotStarted
if not as_dict: if not as_dict:
output = [] output = []
for key, value in self.session._rates._asdict().items(): for key, value in self.http.session._rates._asdict().items():
output.append(f'{key} : {value}') output.append(f'{key} : {value}')
return output return output
return self.session._rates._asdict() return self.http.session._rates._asdict()
def update_auth(self) -> None: async def update_auth(self, username: str, token: str) -> None:
"""Allows you to input auth information after instantiating the client.""" """Allows you to input auth information after instantiating the client."""
username = input('Enter your username: ') await self.http.update_auth(username, token)
token = getpass('Enter your token: ')
self._auth = aiohttp.BasicAuth(username, token)
async def start(self) -> 'GHClient': async def start(self) -> 'GHClient':
"""Main entry point to the wrapper, this creates the ClientSession.""" """Main entry point to the wrapper, this creates the ClientSession."""
if self.has_started: if self.has_started:
raise exceptions.AlreadyStarted raise exceptions.AlreadyStarted
if self._auth: if self._auth:
self.session = await http.make_session(headers=self._headers, authorization=self._auth) self.http = await http(auth=self._auth, headers=self._headers)
try: try:
await self.get_self() await self.http.get_self()
except exceptions.InvalidToken as exc: except exceptions.InvalidToken as exc:
raise exceptions.InvalidToken from exc raise exceptions.InvalidToken from exc
else: else:
self.session = await http.make_session(authorization = self._auth, headers = self._headers) self.http = await http(auth=None, headers=self._headers)
self.has_started = True self.has_started = True
return self return self
@ -102,30 +101,30 @@ class GHClient:
async def get_self(self) -> User: async def get_self(self) -> User:
"""Returns the authenticated User object.""" """Returns the authenticated User object."""
if self._auth: if self._auth:
return await http.get_self(self.session) return User(await self.http.get_self(), self.http.session)
else: else:
raise exceptions.NoAuthProvided raise exceptions.NoAuthProvided
@_cache(type='User') @_cache(type='User')
async def get_user(self, username) -> User: async def get_user(self, username: str) -> User:
"""Fetch a Github user from their username.""" """Fetch a Github user from their username."""
return await http.get_user(self.session, username) return User(await self.http.get_user(username), self.http.session)
@_cache(type='Repo') @_cache(type='Repo')
async def get_repo(self, owner: str, repo: str) -> Repository: async def get_repo(self, owner: str, repo: str) -> Repository:
"""Fetch a Github repository from it's name.""" """Fetch a Github repository from it's name."""
return await http.get_repo_from_name(self.session, owner, repo) return Repository(await self.http.get_repo(owner, repo), self.http.session)
async def get_issue(self, owner: str, repo: str, issue: int) -> Issue: async def get_issue(self, owner: str, repo: str, issue: int) -> Issue:
"""Fetch a Github repository from it's name.""" """Fetch a Github repository from it's name."""
return await http.get_repo_issue(self.session, owner, repo, issue) return Issue(await self.http.get_repo_issue(owner, repo, issue), self.http.session)
async def create_repo(self, name: str, description: str, private: bool, gitignore_template: str) -> Repository: async def create_repo(self, name: str, description: str, private: bool, gitignore_template: str) -> Repository:
"""Create a new Github repository.""" """Create a new Github repository."""
return await http.make_repo(self.session, name, description, private, gitignore_template) return Repository(await self.http.make_repo(name, description, private, gitignore_template), self.http.session)
async def get_org(self, org) -> Organization: async def get_org(self, org: str) -> Organization:
"""Fetch a Github organization from it's name""" """Fetch a Github organization from it's name"""
return await http.get_org(self.session, org) return Organization(await http.get_org(org), self.http.session)