Skip to content

Commit

Permalink
[fix] Invalidate checksum cache on VPN server change openwisp#667
Browse files Browse the repository at this point in the history
  • Loading branch information
codesankalp committed Jul 14, 2022
1 parent d3699c2 commit ac0399c
Show file tree
Hide file tree
Showing 7 changed files with 107 additions and 3 deletions.
11 changes: 11 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3127,6 +3127,17 @@ The signal is emitted when subnets and IP addresses have been provisioned
for a ``VpnClient`` for a VPN server with a subnet with
`subnet division rule <#subnet-division-app>`_.

``vpn_server_modified``
~~~~~~~~~~~~~~~~~~~~~~~

**Path**: ``openwisp_controller.config.signals.vpn_server_modified``

**Arguments**:

- ``instance``: instance of ``Vpn``.

The signal is emitted when the VPN server is modified.

``vpn_peers_changed``
~~~~~~~~~~~~~~~~~~~~~

Expand Down
7 changes: 7 additions & 0 deletions openwisp_controller/config/apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
device_name_changed,
group_templates_changed,
vpn_peers_changed,
vpn_server_modified,
)

# ensure Device.hardware_id field is not flagged as unique
Expand Down Expand Up @@ -229,6 +230,7 @@ def enable_cache_invalidation(self):
device_cache_invalidation_handler,
devicegroup_change_handler,
devicegroup_delete_handler,
vpn_server_change_handler,
)

post_save.connect(
Expand Down Expand Up @@ -270,6 +272,11 @@ def enable_cache_invalidation(self):
sender=self.device_model,
dispatch_uid='device.invalidate_cache',
)
vpn_server_modified.connect(
vpn_server_change_handler,
sender=self.vpn_model,
dispatch_uid='vpn.invalidate_checksum_cache',
)

def register_dashboard_charts(self):
register_dashboard_chart(
Expand Down
48 changes: 47 additions & 1 deletion openwisp_controller/config/base/vpn.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from ...base import ShareableOrgMixinUniqueName
from .. import crypto
from .. import settings as app_settings
from ..signals import vpn_peers_changed
from ..signals import vpn_peers_changed, vpn_server_modified
from ..tasks import create_vpn_dh, trigger_vpn_server_endpoint
from .base import BaseConfig

Expand Down Expand Up @@ -127,6 +127,11 @@ class Meta:
unique_together = ('organization', 'name')
abstract = True

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# for internal usage
self._send_vpn_modified_after_save = False

def clean(self, *args, **kwargs):
super().clean(*args, **kwargs)
self._validate_backend()
Expand Down Expand Up @@ -190,6 +195,9 @@ def save(self, *args, **kwargs):
"""
Calls _auto_create_cert() if cert is not set
"""
created = self._state.adding
if not created:
self._check_changes()
create_dh = False
if not self.cert and self.ca:
self.cert = self._auto_create_cert()
Expand All @@ -203,8 +211,35 @@ def save(self, *args, **kwargs):
super().save(*args, **kwargs)
if create_dh:
transaction.on_commit(lambda: create_vpn_dh.delay(self.id))
if not created and self._send_vpn_modified_after_save:
self._send_vpn_modified_signal()
self._send_vpn_modified_after_save = False
self.update_vpn_server_configuration()

def _check_changes(self):
attrs = [
'config',
'host',
'ca',
'cert',
'key',
'backend',
'subnet',
'ip',
'dh',
'public_key',
'private_key',
]
current = self._meta.model.objects.only(*attrs).get(pk=self.pk)
for attr in attrs:
if getattr(self, attr) == getattr(current, attr):
continue
self._send_vpn_modified_after_save = True
break

def _send_vpn_modified_signal(self):
vpn_server_modified.send(sender=self.__class__, instance=self)

@classmethod
def dhparam(cls, length):
"""
Expand Down Expand Up @@ -736,3 +771,14 @@ def _auto_ip(self):
if func(self):
return
self.ip = self.vpn.subnet.request_ip()

@classmethod
def invalidate_clients_cache(cls, vpn):
"""
Invalidate checksum cache for clients that uses this VPN server
"""
for client in vpn.vpnclient_set.iterator():
# invalidate cache for device
client.config._send_config_modified_signal(
action='related_template_changed'
)
6 changes: 6 additions & 0 deletions openwisp_controller/config/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,12 @@ def config_backend_change_handler(instance, **kwargs):
devicegroup_templates_change_handler(instance, **kwargs)


def vpn_server_change_handler(instance, **kwargs):
transaction.on_commit(
lambda: tasks.invalidate_vpn_server_devices_cache_change.delay(instance.id)
)


def devicegroup_templates_change_handler(instance, **kwargs):
if type(instance) is list:
# instance is queryset of devices
Expand Down
4 changes: 4 additions & 0 deletions openwisp_controller/config/signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,7 @@
config_backend_changed.__doc__ = """
providing arguments: ['instance', 'backend', 'old_backend']
"""
vpn_server_modified = Signal()
vpn_server_modified.__doc__ = """
providing arguments: ['instance']
"""
8 changes: 8 additions & 0 deletions openwisp_controller/config/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,14 @@ def invalidate_devicegroup_cache_change(instance_id, model_name):
DeviceGroupCommonName.certificate_change_invalidates_cache(instance_id)


@shared_task(base=OpenwispCeleryTask)
def invalidate_vpn_server_devices_cache_change(vpn_pk):
Vpn = load_model('config', 'Vpn')
VpnClient = load_model('config', 'VpnClient')
vpn = Vpn.objects.get(pk=vpn_pk)
VpnClient.invalidate_clients_cache(vpn)


@shared_task(base=OpenwispCeleryTask)
def invalidate_devicegroup_cache_delete(instance_id, model_name, **kwargs):
from .api.views import DeviceGroupCommonName
Expand Down
26 changes: 24 additions & 2 deletions openwisp_controller/config/tests/test_vpn.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from ...vpn_backends import OpenVpn
from .. import settings as app_settings
from ..signals import vpn_peers_changed
from ..signals import config_modified, vpn_peers_changed, vpn_server_modified
from ..tasks import create_vpn_dh
from .utils import (
CreateConfigTemplateMixin,
Expand Down Expand Up @@ -442,7 +442,7 @@ def test_cert_validation(self):
self.assertIn('CA is required with this VPN backend', message_dict['ca'])


class TestVpnTransaction(BaseTestVpn, TransactionTestCase):
class TestVpnTransaction(BaseTestVpn, TestWireguardVpnMixin, TransactionTestCase):
@mock.patch.object(create_vpn_dh, 'delay')
def test_create_vpn_dh_with_vpn_create(self, delay):
vpn = self._create_vpn(dh='')
Expand All @@ -463,6 +463,28 @@ def test_update_vpn_dh(self, dhparam):
self.assertNotEqual(vpn.dh, Vpn._placeholder_dh)
dhparam.assert_called_once()

def test_vpn_server_change_invalidates_device_cache(self):
device, vpn, template = self._create_wireguard_vpn_template()
with catch_signal(
vpn_server_modified
) as mocked_vpn_server_modified, catch_signal(
config_modified
) as mocked_config_modified:
vpn.host = 'localhost'
vpn.save(update_fields=['host'])
mocked_vpn_server_modified.assert_called_once_with(
signal=vpn_server_modified, sender=Vpn, instance=vpn
)
mocked_config_modified.assert_called_once_with(
signal=config_modified,
sender=Config,
instance=device.config,
previous_status='modified',
action='related_template_changed',
config=device.config,
device=device,
)


class TestWireguard(BaseTestVpn, TestWireguardVpnMixin, TestCase):
def test_wireguard_config_creation(self):
Expand Down

0 comments on commit ac0399c

Please sign in to comment.