diff --git a/Github/http.py b/Github/http.py index c07d947..4250411 100644 --- a/Github/http.py +++ b/Github/http.py @@ -13,11 +13,8 @@ from .objects import * from .urls import * __all__ = ( - 'make_session', 'Paginator', - 'get_user', - 'get_repo_from_name', - 'get_repo_issue', + 'http', ) @@ -112,53 +109,92 @@ class Paginator: raise WillExceedRatelimit(response, self.max_page) self.bare_link = groups[0][0][:-1] -# user-related functions / utils -async def get_self(session: aiohttp.ClientSession) -> User: - result = await session.get(SELF_URL) - if result.status == 200: - return User(await result.json(), session) - raise InvalidToken +GithubUserData = GithubRepoData = GithubIssueData = GithubOrgData = dict[str, str | int] -async def get_user(session: aiohttp.ClientSession, username: str) -> User: - """Returns a user's public data in JSON format.""" - result = await session.get(USERS_URL.format(username)) - if result.status == 200: - return User(await result.json(), session) - raise UserNotFound +class http: + def __init__(self, headers: dict[str, str | int], auth: aiohttp.BasicAuth | None): + if not headers.get('User-Agent'): + headers['User-Agent'] = 'Github-API-Wrapper' + self._rates = Rates('', '', '', '', '') + self.headers = headers + self.auth = auth + def __await__(self): + return self.start().__await__() -# repo-related functions / utils -async def get_repo_from_name(session: aiohttp.ClientSession, owner: str, repo: str) -> Repository: - """Returns a Repo object from the given owner and repo name.""" - result = await session.get(REPO_URL.format(owner, repo)) - if result.status == 200: - return Repository(await result.json(), session) - raise RepositoryNotFound + async def start(self): + self.session = aiohttp.ClientSession( + headers=self.headers, + auth=self.auth, + trace_configs=[trace_config], + ) + return self -async def get_repo_issue(session: aiohttp.ClientSession, owner: str, repo: str, issue: int) -> Issue: - """Returns a single issue from the given owner and repo name.""" - result = await session.get(REPO_ISSUE_URL.format(owner, repo, issue)) - if result.status == 200: - return Issue(await result.json(), session) - raise IssueNotFound + def update_headers(self, *, flush: bool = False, new_headers: dict[str, str | int]): + if flush: + from multidict import CIMultiDict + self.session.headers = CIMultiDict(**new_headers) + else: + self.session.headers = {**self.session.headers, **new_headers} -async def create_repo(session: aiohttp.ClientSession, name: str, description: str, private: bool, gitignore_template: str, **kwargs) -> Repository: - """Creates a new repo with the given name.""" - _data = {"name" : name, "description" : description, "private" : private, "gitignore_template" : gitignore_template} - result = await session.post(MAKE_REPO_URL, data= json.dumps(_data)) - if result.status == 201: - return Repository(await result.json(), session) - if result.status == 401: - raise NoAuthProvided - raise RepositoryAlreadyExists + async def update_auth(self, *, username: str, token: str): + auth = aiohttp.BasicAuth(username, token) + headers = self.session.headers + config = self.session.trace_configs + await self.session.close() + self.session = aiohttp.ClientSession( + headers=headers, + auth=auth, + trace_configs=config + ) + async def get_self(self) -> GithubUserData: + """Returns the authenticated User's data""" + result = await self.session.get(SELF_URL) + if 200 <= result.status <= 299: + return await result.json() + raise InvalidToken + async def get_user(self, username: str) -> GithubUserData: + """Returns a user's public data in JSON format.""" + result = await self.session.get(USERS_URL.format(username)) + if 200 <= result.status <= 299: + return await result.json() + raise UserNotFound -# org-related functions / utils + async def get_repo(self, owner: str, repo_name: str) -> GithubRepoData: + """Returns a Repo's raw JSON from the given owner and repo name.""" + result = await self.session.get(REPO_URL.format(owner, repo_name)) + if 200 <= result.status <= 299: + return await result.json() + raise RepositoryNotFound + + async def get_repo_issue(self, owner: str, repo_name: str, issue_number: int) -> GithubIssueData: + """Returns a single issue's JSON from the given owner and repo name.""" + result = await self.session.get(REPO_ISSUE_URL.format(owner, repo_name, issue_number)) + if 200 <= result.status <= 299: + return await result.json() + raise IssueNotFound -async def get_org(session: aiohttp.ClientSession, org: str) -> Organization: - """Returns an org's public data in JSON format.""" - result = await session.get(ORG_URL.format(org)) - if result.status == 200: - return Organization(await result.json(), session) - raise OrganizationNotFound \ No newline at end of file + async def create_repo(self, **kwargs: dict[str, str | bool]) -> GithubRepoData: + """Creates a repo for you with given data""" + data = { + 'name': kwargs.get('name'), + 'description': kwargs.get('description'), + 'private': kwargs.get('private', True), + 'gitignore_template': kwargs.get('gitignore'), + 'license': kwargs.get('license'), + } + result = await self.session.post(CREATE_REPO_URL, data=json.dumps(data)) + if 200 <= result.status <= 299: + return await result.json() + if result.status == 401: + raise NoAuthProvided + raise RepositoryAlreadyExists + + async def get_org(self, org_name: str) -> GithubOrgData: + """Returns an org's public data in JSON format.""" + result = await self.session.get(ORG_URL.format(org_name)) + if 200 <= result.status <= 299: + return await result.json() + raise OrganizationNotFound diff --git a/Github/main.py b/Github/main.py index d761f79..1623928 100644 --- a/Github/main.py +++ b/Github/main.py @@ -9,7 +9,7 @@ import asyncio import functools from getpass import getpass -from . import http +from .http import http from . import exceptions from .objects import User, PartialUser, Repository, Organization, Issue from .cache import UserCache, RepoCache @@ -17,6 +17,7 @@ from .cache import UserCache, RepoCache class GHClient: _auth = None has_started = False + http: http def __init__( self, *, @@ -43,36 +44,34 @@ class GHClient: return f'' def __del__(self): - asyncio.create_task(self.session.close()) + asyncio.create_task(self.http.session.close()) def check_limits(self, as_dict: bool = False) -> dict[str, str | int] | list[str]: if not self.has_started: raise exceptions.NotStarted if not as_dict: output = [] - for key, value in self.session._rates._asdict().items(): + for key, value in self.http.session._rates._asdict().items(): output.append(f'{key} : {value}') return output - return self.session._rates._asdict() + return self.http.session._rates._asdict() - def update_auth(self) -> None: + async def update_auth(self, username: str, token: str) -> None: """Allows you to input auth information after instantiating the client.""" - username = input('Enter your username: ') - token = getpass('Enter your token: ') - self._auth = aiohttp.BasicAuth(username, token) + await self.http.update_auth(username, token) async def start(self) -> 'GHClient': """Main entry point to the wrapper, this creates the ClientSession.""" if self.has_started: raise exceptions.AlreadyStarted if self._auth: - self.session = await http.make_session(headers=self._headers, authorization=self._auth) + self.http = await http(auth=self._auth, headers=self._headers) try: - await self.get_self() + await self.http.get_self() except exceptions.InvalidToken as exc: raise exceptions.InvalidToken from exc else: - self.session = await http.make_session(authorization = self._auth, headers = self._headers) + self.http = await http(auth=None, headers=self._headers) self.has_started = True return self @@ -102,30 +101,30 @@ class GHClient: async def get_self(self) -> User: """Returns the authenticated User object.""" if self._auth: - return await http.get_self(self.session) + return User(await self.http.get_self(), self.http.session) else: raise exceptions.NoAuthProvided @_cache(type='User') - async def get_user(self, username) -> User: + async def get_user(self, username: str) -> User: """Fetch a Github user from their username.""" - return await http.get_user(self.session, username) + return User(await self.http.get_user(username), self.http.session) @_cache(type='Repo') async def get_repo(self, owner: str, repo: str) -> Repository: """Fetch a Github repository from it's name.""" - return await http.get_repo_from_name(self.session, owner, repo) + return Repository(await self.http.get_repo(owner, repo), self.http.session) async def get_issue(self, owner: str, repo: str, issue: int) -> Issue: """Fetch a Github repository from it's name.""" - return await http.get_repo_issue(self.session, owner, repo, issue) + return Issue(await self.http.get_repo_issue(owner, repo, issue), self.http.session) async def create_repo(self, name: str, description: str, private: bool, gitignore_template: str) -> Repository: """Create a new Github repository.""" - return await http.make_repo(self.session, name, description, private, gitignore_template) + return Repository(await self.http.make_repo(name, description, private, gitignore_template), self.http.session) - async def get_org(self, org) -> Organization: + async def get_org(self, org: str) -> Organization: """Fetch a Github organization from it's name""" - return await http.get_org(self.session, org) + return Organization(await http.get_org(org), self.http.session)