Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions duo_universal/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def _create_jwt_args(self, endpoint):

def __init__(self, client_id, client_secret, host,
redirect_uri, duo_certs=DEFAULT_CA_CERT_PATH, use_duo_code_attribute=True, http_proxy=None,
exp_seconds=FIVE_MINUTES_IN_SECONDS):
exp_seconds=FIVE_MINUTES_IN_SECONDS, disable_ca_pinning=False):
"""
Initializes instance of Client class

Expand All @@ -144,6 +144,10 @@ def __init__(self, client_id, client_secret, host,
use_duo_code_attribute -- (Optional: default true) Flag to use `duo_code` instead of `code` for returned authorization parameter
http_proxy -- (Optional) HTTP proxy to tunnel requests through
exp_seconds -- (Optional) The number of seconds used for JWT expiry. Must be be at most 5 minutes.
disable_ca_pinning -- (Optional: default false) If True, uses the system's default
trusted CA certificates instead of Duo's bundled CA certificates.
TLS verification remains active. Cannot be used together with
custom duo_certs.
"""

self._validate_init_config(client_id,
Expand All @@ -158,9 +162,16 @@ def __init__(self, client_id, client_secret, host,
self._redirect_uri = redirect_uri
self._use_duo_code_attribute = use_duo_code_attribute

# If duo_certs is None set it to the DEFAULT_CA_CERT_PATH
# so that we make sure we are pinning certs
if duo_certs is not None:
if disable_ca_pinning and duo_certs not in (None, DEFAULT_CA_CERT_PATH):
raise DuoException(
"Cannot both disable CA pinning and provide custom CA certificates"
)

self._disable_ca_pinning = disable_ca_pinning

if disable_ca_pinning:
self._duo_certs = True
elif duo_certs is not None:
if duo_certs == "DISABLE":
self._duo_certs = False
else:
Expand Down
95 changes: 95 additions & 0 deletions tests/test_setup_client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from unittest.mock import patch, MagicMock
from duo_universal import client
import unittest

Expand Down Expand Up @@ -158,5 +159,99 @@ def test_proxy_set_off_kwargs(self):
self.assertEqual(client_with_no_proxy._http_proxy, NONE)


class TestDisableCaPinning(unittest.TestCase):

def test_default_is_pinning_enabled(self):
c = client.Client(CLIENT_ID, CLIENT_SECRET, HOST, REDIRECT_URI)
self.assertFalse(c._disable_ca_pinning)
self.assertEqual(c._duo_certs, client.DEFAULT_CA_CERT_PATH)

def test_disable_ca_pinning_true(self):
c = client.Client(CLIENT_ID, CLIENT_SECRET, HOST, REDIRECT_URI,
disable_ca_pinning=True)
self.assertTrue(c._disable_ca_pinning)
self.assertTrue(c._duo_certs)

def test_disable_ca_pinning_with_default_duo_certs(self):
c = client.Client(CLIENT_ID, CLIENT_SECRET, HOST, REDIRECT_URI,
duo_certs=client.DEFAULT_CA_CERT_PATH, disable_ca_pinning=True)
self.assertTrue(c._disable_ca_pinning)
self.assertTrue(c._duo_certs)

def test_disable_ca_pinning_with_none_duo_certs(self):
c = client.Client(CLIENT_ID, CLIENT_SECRET, HOST, REDIRECT_URI,
duo_certs=None, disable_ca_pinning=True)
self.assertTrue(c._disable_ca_pinning)
self.assertTrue(c._duo_certs)

def test_disable_ca_pinning_with_custom_duo_certs_raises(self):
with self.assertRaises(client.DuoException) as ctx:
client.Client(CLIENT_ID, CLIENT_SECRET, HOST, REDIRECT_URI,
duo_certs=CA_CERT_NEW, disable_ca_pinning=True)
self.assertIn("Cannot both disable CA pinning", str(ctx.exception))

def test_disable_ca_pinning_with_disable_duo_certs_raises(self):
with self.assertRaises(client.DuoException) as ctx:
client.Client(CLIENT_ID, CLIENT_SECRET, HOST, REDIRECT_URI,
duo_certs="DISABLE", disable_ca_pinning=True)
self.assertIn("Cannot both disable CA pinning", str(ctx.exception))

def test_disable_ca_pinning_false_preserves_existing_behavior(self):
c = client.Client(CLIENT_ID, CLIENT_SECRET, HOST, REDIRECT_URI,
disable_ca_pinning=False)
self.assertEqual(c._duo_certs, client.DEFAULT_CA_CERT_PATH)


class TestDisableCaPinningRequests(unittest.TestCase):

@patch('requests.post')
def test_health_check_pinning_disabled_uses_system_trust_store(self, requests_mock):
requests_mock.return_value = MagicMock(content=b'{"stat": "OK", "response": {"timestamp": 1}}')
c = client.Client(CLIENT_ID, CLIENT_SECRET, HOST, REDIRECT_URI,
disable_ca_pinning=True)
c.health_check()
_, kwargs = requests_mock.call_args
self.assertTrue(kwargs['verify'])
self.assertIsNot(kwargs['verify'], client.DEFAULT_CA_CERT_PATH)

@patch('requests.post')
def test_health_check_pinning_enabled_uses_bundled_certs(self, requests_mock):
requests_mock.return_value = MagicMock(content=b'{"stat": "OK", "response": {"timestamp": 1}}')
c = client.Client(CLIENT_ID, CLIENT_SECRET, HOST, REDIRECT_URI)
c.health_check()
_, kwargs = requests_mock.call_args
self.assertEqual(kwargs['verify'], client.DEFAULT_CA_CERT_PATH)

@patch('requests.post')
def test_token_exchange_pinning_disabled_uses_system_trust_store(self, requests_mock):
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {'id_token': 'fake'}
requests_mock.return_value = mock_response
c = client.Client(CLIENT_ID, CLIENT_SECRET, HOST, REDIRECT_URI,
disable_ca_pinning=True)
try:
c.exchange_authorization_code_for_2fa_result('code', 'user')
except client.DuoException:
pass
_, kwargs = requests_mock.call_args
self.assertTrue(kwargs['verify'])
self.assertIsNot(kwargs['verify'], client.DEFAULT_CA_CERT_PATH)

@patch('requests.post')
def test_token_exchange_pinning_enabled_uses_bundled_certs(self, requests_mock):
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {'id_token': 'fake'}
requests_mock.return_value = mock_response
c = client.Client(CLIENT_ID, CLIENT_SECRET, HOST, REDIRECT_URI)
try:
c.exchange_authorization_code_for_2fa_result('code', 'user')
except client.DuoException:
pass
_, kwargs = requests_mock.call_args
self.assertEqual(kwargs['verify'], client.DEFAULT_CA_CERT_PATH)


if __name__ == '__main__':
unittest.main()
Loading