diff --git a/juju/client/connection.py b/juju/client/connection.py index 45ee6d95..7980ce03 100644 --- a/juju/client/connection.py +++ b/juju/client/connection.py @@ -256,7 +256,8 @@ async def connect( specified_facades=None, proxy=None, debug_log_conn=None, - debug_log_params={} + debug_log_params={}, + cookie_domain=None, ): """Connect to the websocket. @@ -285,8 +286,10 @@ async def connect( to prevent using the conservative client pinning with in the client. :param TextIOWrapper debug_log_conn: target if this is a debug log connection :param dict debug_log_params: filtering parameters for the debug-log output + :param str cookie_domain: Which domain the controller uses for cookies. """ self = cls() + self.cookie_domain = cookie_domain if endpoint is None: raise ValueError('no endpoint provided') if not isinstance(endpoint, str) and not isinstance(endpoint, list): @@ -301,7 +304,7 @@ async def connect( if password is not None: raise errors.JujuAuthError('cannot log in as external ' 'user with a password') - username = None + self.usertag = tag.user(username) self.password = password @@ -762,6 +765,7 @@ def connect_params(self): 'bakery_client': self.bakery_client, 'max_frame_size': self.max_frame_size, 'proxy': self.proxy, + 'cookie_domain': self.cookie_domain, } async def controller(self): @@ -774,6 +778,7 @@ async def controller(self): cacert=self.cacert, bakery_client=self.bakery_client, max_frame_size=self.max_frame_size, + cookie_domain=self.cookie_domain, ) async def reconnect(self): @@ -871,7 +876,7 @@ async def _connect_with_login(self, endpoints): # a few times. for i in range(0, 2): result = (await self.login())['response'] - macaroonJSON = result.get('discharge-required') + macaroonJSON = result.get('bakery-discharge-required') if macaroonJSON is None: self.info = result success = True @@ -962,7 +967,7 @@ async def login(self): params['credentials'] = self.password else: macaroons = _macaroons_for_domain(self.bakery_client.cookies, - self.endpoint) + self.cookie_domain) params['macaroons'] = [[bakery.macaroon_to_dict(m) for m in ms] for ms in macaroons] diff --git a/juju/client/connector.py b/juju/client/connector.py index 3a901c7d..ca8a7289 100644 --- a/juju/client/connector.py +++ b/juju/client/connector.py @@ -63,17 +63,33 @@ async def connect(self, **kwargs): kwargs are passed through to Connection.connect() """ - kwargs.setdefault("max_frame_size", self.max_frame_size) - kwargs.setdefault("bakery_client", self.bakery_client) - if "macaroons" in kwargs: - if not kwargs["bakery_client"]: - kwargs["bakery_client"] = httpbakery.Client() - if not kwargs["bakery_client"].cookies: - kwargs["bakery_client"].cookies = GoCookieJar() - jar = kwargs["bakery_client"].cookies - for macaroon in kwargs.pop("macaroons"): - jar.set_cookie(go_to_py_cookie(macaroon)) - if "debug_log_conn" in kwargs: + kwargs.setdefault('max_frame_size', self.max_frame_size) + kwargs.setdefault('bakery_client', self.bakery_client) + kwargs.setdefault('cookie_domain', self.jujudata.cookie_domain_for_controller(endpoint=kwargs.get("endpoints"))) + + account = kwargs.pop('account', {}) + # Prioritize the username and password that user provided + # If not enough, try to patch it with info from accounts.yaml + if 'username' not in kwargs and account.get('user'): + kwargs.update(username=account.get('user')) + if 'password' not in kwargs and account.get('password'): + kwargs.update(password=account.get('password')) + + if not kwargs["bakery_client"]: + if 'macaroons' in kwargs: + if not kwargs['bakery_client']: + kwargs['bakery_client'] = httpbakery.Client() + if not kwargs['bakery_client'].cookies: + kwargs['bakery_client'].cookies = GoCookieJar() + jar = kwargs['bakery_client'].cookies + for macaroon in kwargs.pop('macaroons'): + jar.set_cookie(go_to_py_cookie(macaroon)) + else: + if not ({'username', 'password'}.issubset(kwargs)): + required = {'username', 'password'}.difference(kwargs) + raise ValueError(f'Some authentication parameters are required : {",".join(required)}') + + if 'debug_log_conn' in kwargs: assert self._connection self._log_connection = await Connection.connect(**kwargs) else: @@ -85,18 +101,6 @@ async def connect(self, **kwargs): # connected to. if self._connection: await self._connection.close() - - account = kwargs.pop('account', {}) - # Prioritize the username and password that user provided - # If not enough, try to patch it with info from accounts.yaml - if 'username' not in kwargs and account.get('user'): - kwargs.update(username=account.get('user')) - if 'password' not in kwargs and account.get('password'): - kwargs.update(password=account.get('password')) - - if not ({'username', 'password'}.issubset(kwargs)): - required = {'username', 'password'}.difference(kwargs) - raise ValueError(f'Some authentication parameters are required : {",".join(required)}') self._connection = await Connection.connect(**kwargs) # Check if we support the target controller diff --git a/juju/client/gocookies.py b/juju/client/gocookies.py index e53ccde9..ad1ad327 100644 --- a/juju/client/gocookies.py +++ b/juju/client/gocookies.py @@ -5,14 +5,30 @@ import http.cookiejar as cookiejar import json import time - import pyrfc3339 +class JujuCookiePolicy(cookiejar.DefaultCookiePolicy): + '''A cookie policy that allows arbitrary strings to be used for cookie domains + as long as they match the host in the request exactly. This is necessary for interacting + with the juju controller, which uses UUID's for cookie domains in some circumstances.''' + def return_ok_domain(self, cookie, request): + return cookie.domain == request.host or super().return_ok_domain(cookie, request) + + def domain_return_ok(self, domain, request): + return domain == request.host or super().domain_return_ok(domain, request) + + class GoCookieJar(cookiejar.FileCookieJar): '''A CookieJar implementation that reads and writes cookies to the cookiejar format as understood by the Go package github.com/juju/persistent-cookiejar.''' + def __init__(self, filename=None, delayload=False, policy=None): + if policy is None: + policy = JujuCookiePolicy() + + super().__init__(filename, delayload, policy) + def _really_load(self, f, filename, ignore_discard, ignore_expires): '''Implement the _really_load method called by FileCookieJar to implement the actual cookie loading''' diff --git a/juju/client/jujudata.py b/juju/client/jujudata.py index f48858b3..52605bfd 100644 --- a/juju/client/jujudata.py +++ b/juju/client/jujudata.py @@ -182,3 +182,28 @@ def cookies_for_controller(self, controller_name): jar = GoCookieJar(str(f)) jar.load() return jar + + def cookie_domain_for_controller(self, controller_name=None, endpoint=None): + '''Returns the correct cookie domain. + + The cookie domain used by the controller is either the + field public-hostname in the controllers.yaml file, or the uuid if this is + not available. If neither controller_name nor endpoint are specified, assume + the current controller. + + :param str controller_name: The name of the controller. + :param str endpoint: The endpoint of the controller. + ''' + if all([controller_name, endpoint]): + raise ValueError('Specify either controller_name or endpoint, not both') + if controller_name: + controller = controller_name + elif endpoint: + controller = self.controller_name_by_endpoint(endpoint) + else: + controller = self.current_controller() + + controllers = self.controllers() + controller_data = controllers.get(controller) + + return controller_data.get('public-hostname', controller_data.get('uuid'))