# Mini Shell
"""Define ipsets/iptables rules for DEF-16003 synclists.
i.e. for ip network blocks/unblocks received from /api/sync/v1/iplist
correlation server endpoint.
The expected behavior (requirements) is defined in
src/handbook/message_processing/server_sync.py
"""
import dataclasses
import ipaddress
import itertools
import re
from abc import ABCMeta, abstractmethod
from logging import getLogger
from typing import (
AbstractSet,
Dict,
FrozenSet,
Iterable,
Iterator,
List,
Set,
TypeVar,
)
from im360.contracts.config import Webshield as WebshieldConfig
from im360.internals.core import FirewallRules, rules
from im360.internals.core.ipset import (
IPSetAtomicRestoreBase,
IPSetCollectionResetMixin,
libipset,
)
from im360.internals.core.ipset.libipset import (
IPSetCmdBuilder,
IPSetRestoreCmd,
)
from im360.internals.core.rules import FirewallRule
from im360.model.firewall import IPListID, IPListPurpose, IPListRecord, Purpose
from im360.subsys import webshield
from defence360agent.utils.validate import IP, IPVersion
from . import IP_SET_PREFIX, AbstractIPSet, IPSetCount, get_ipset_family
# Module-level logger named after this module.
logger = getLogger(__name__)
# Public API of the module: the four per-purpose "sync" ipset classes and
# the facade (IPSetSyncIPListPurpose) that manages them as a collection.
__all__ = [
"IPSetSyncCaptcha",
"IPSetSyncDrop",
"IPSetSyncIPListPurpose",
"IPSetSyncSplashscreen",
"IPSetSyncWhite",
]
# Type variable used to annotate pass-through **kwargs
# (see IPSetSyncIPListRecords.get_rules).
Args = TypeVar("Args")
# Generic type variable.
# NOTE(review): not referenced in the visible part of this file.
T = TypeVar("T")
class IPSetSync(IPSetAtomicRestoreBase, metaclass=ABCMeta):
    """Base class for the per-purpose "*.sync" list:set ipsets.

    A concrete subclass binds one purpose and supplies the iptables rules
    for it.  The interface mimics .ip.BaseIPSet just enough for
    .ip.IPSet.get_rules() to work with these objects.
    """

    DB_NAME = None  # not in iplist table

    @property
    @abstractmethod
    def purpose(self) -> Purpose:
        """Purpose of this ipset; it becomes part of the ipset name."""
        raise NotImplementedError  # pragma: no cover

    def create_rules(self, ip_version: IPVersion) -> Iterable[dict]:
        """Return this ipset's firewall rules converted to plain dicts.

        To be called by .ip.IPSet.get_rules().
        """
        firewall_rules = self._rules(
            self.gen_ipset_name_for_ip_version(ip_version),
            ip_version=ip_version,
        )
        return map(dataclasses.asdict, firewall_rules)

    @abstractmethod
    def _rules(
        self, ipset_name: str, ip_version: IPVersion
    ) -> Iterator[FirewallRule]:
        """Produce the firewall rules that reference *ipset_name*."""
        raise NotImplementedError  # pragma: no cover

    def is_enabled(self, ip_version: IPVersion = None) -> bool:
        """Tell whether it makes sense to call *create_rules*.

        The base implementation is unconditionally enabled.
        """
        return True

    def gen_ipset_name_for_ip_version(self, ip_version: IPVersion) -> str:
        """Return the ipset name: the custom one if set, else the default."""
        custom_name = self.custom_ipset_name
        if custom_name:
            return custom_name
        return f"{IP_SET_PREFIX}.{ip_version}.{self.purpose}.sync"

    async def get_db_count(self, ip_version: IPVersion) -> int:
        """.ip.BaseIPSet method (stub, always 0)."""
        # the real counts are produced by
        # IPSetSyncIPListPurpose.get_ipsets_count()
        return 0

    def gen_ipset_create_ops(self, ip_version: IPVersion) -> List[str]:
        """Return the command creating this list:set ipset.

        The list size is twice the number of ipset names known to
        IPSetSyncIPListRecords for *ip_version*, but never below 32.
        """
        member_count = len(IPSetSyncIPListRecords.all_ipsets(ip_version))
        list_size = max(32, 2 * member_count)
        set_name = self.gen_ipset_name_for_ip_version(ip_version)
        return [
            IPSetCmdBuilder.get_create_list_set_cmd(set_name, size=list_size)
        ]

    def gen_ipset_destroy_ops(self, ip_version: IPVersion) -> List[str]:
        """Return the command destroying this ipset."""
        return [
            IPSetCmdBuilder.get_destroy_cmd(
                self.gen_ipset_name_for_ip_version(ip_version)
            )
        ]

    def gen_ipset_flush_ops(self, ip_version: IPVersion) -> List[str]:
        """Return the command flushing this ipset."""
        set_name = self.gen_ipset_name_for_ip_version(ip_version)
        return [IPSetCmdBuilder.get_flush_cmd(set_name)]

    async def gen_ipset_restore_ops(self, ip_version: IPVersion) -> List[str]:
        """Return commands that add this purpose's member ipsets to the set.

        A member ipset missing from the system is created and filled
        before the command that adds it to this list:set.
        """
        list_set_name = self.gen_ipset_name_for_ip_version(ip_version)
        cmds: List[str] = []
        members = IPSetSyncIPListRecords().get_ipset_instances_by_purpose(
            ip_version, [self.purpose]
        )
        for member in members:
            # NOTE(review): indentation was lost in the reviewed copy of
            # this file; the fill step is kept inside the `if` on the
            # assumption that only freshly created member sets need to be
            # populated here -- confirm against the upstream source.
            if not await member.exists(ip_version):
                cmds.extend(member.gen_ipset_create_ops(ip_version))
                cmds.extend(await member.gen_ipset_restore_ops(ip_version))
            member_name = member.gen_ipset_name_for_ip_version(ip_version)
            cmds.append(IPSetCmdBuilder.get_add_cmd(list_set_name, member_name))
        return cmds
class IPSetSyncCaptcha(IPSetSync):
    """iptables rules for the *captcha.sync ipsets."""

    purpose = Purpose.CAPTCHA

    def is_enabled(self, ip_version: IPVersion = None) -> bool:
        """Enabled when the base check passes and captcha rules are on."""
        base_enabled = super().is_enabled()
        if not base_enabled:
            return base_enabled
        return _captcha_rules_enabled()

    def _rules(
        self, ipset_name: str, ip_version: IPVersion
    ) -> Iterator[FirewallRule]:
        """Build webshield rules for *ipset_name* with the captcha builder."""
        assert self.is_enabled()
        captcha_builder = rules.CaptchaRuleBuilder(
            include_webshield_ports_rules=True
        )
        return rules.webshield_rules(ipset_name, ip_version, captcha_builder)
class IPSetSyncDrop(IPSetSync):
    """iptables rules for the *drop.sync ipsets."""

    purpose = Purpose.DROP
    # priority handed to rules.drop_rules() for the generated rules
    PRIORITY = FirewallRules.DROP_SYNC_PRIORITY

    def _rules(
        self, ipset_name: str, ip_version: IPVersion
    ) -> Iterator[FirewallRule]:
        """Build drop rules for *ipset_name* at this class's PRIORITY."""
        assert self.is_enabled()
        drop_priority = self.PRIORITY
        return rules.drop_rules(
            ipset_name, ip_version, priority=drop_priority
        )
class IPSetSyncSplashscreen(IPSetSync):
    """iptables rules for the *splashscreen.sync ipsets."""

    purpose = Purpose.SPLASHSCREEN

    def is_enabled(self, ip_version: IPVersion = None) -> bool:
        """Enabled when the base check passes and splashscreen rules are on."""
        base_enabled = super().is_enabled()
        if not base_enabled:
            return base_enabled
        return _splashscreen_rules_enabled()

    def _rules(
        self, ipset_name: str, ip_version: IPVersion
    ) -> Iterator[FirewallRule]:
        """Build webshield rules for *ipset_name* with the splashscreen builder."""
        assert self.is_enabled()
        splash_builder = rules.SplashscreenRuleBuilder()
        return rules.webshield_rules(ipset_name, ip_version, splash_builder)
class IPSetSyncWhite(IPSetSync):
    """iptables rules for the *white.sync ipsets."""

    purpose = Purpose.WHITE

    def _rules(
        self, ipset_name: str, ip_version: IPVersion
    ) -> Iterator[FirewallRule]:
        """Build whitelist rules for *ipset_name*."""
        assert self.is_enabled()
        white = rules.white_rules(ipset_name, ip_version)
        return white
class IPSetSyncIPListPurpose(AbstractIPSet):
    """Facade to manage /api/sync remote iplist ipsets.

    Holds one IPSetSync instance per purpose (white, drop, splashscreen,
    captcha) and offers collection-level operations over them.
    """

    def __init__(self):
        super().__init__()
        self.ip_sets = [
            IPSetSyncWhite(),
            IPSetSyncDrop(),
            IPSetSyncSplashscreen(),
            IPSetSyncCaptcha(),
        ]

    def get_all_ipsets(self, ip_version: IPVersion) -> FrozenSet[str]:
        """Return required [system] "sync" ipset names.

        It does not check whether the ipsets are actually present on
        the system.
        """
        return frozenset(
            _ipset.gen_ipset_name_for_ip_version(ip_version)
            for _ipset in self.get_all_ipset_instances(ip_version)
        )

    def get_all_ipset_instances(
        self, ip_version: IPVersion = None
    ) -> List[IPSetSync]:
        """Return every managed sync ipset (*ip_version* is not used)."""
        return self.ip_sets

    def get_enabled_ipset_instances(self) -> List[IPSetSync]:
        """Return the managed sync ipsets whose is_enabled() is true."""
        return [
            _ipset
            for _ipset in self.get_all_ipset_instances()
            if _ipset.is_enabled()
        ]

    def get_ipset_instances_by_purpose(self, purpose: str) -> IPSetSync:
        """Return the managed sync ipset serving *purpose*.

        Raises:
            ValueError: no managed ipset has that purpose.
        """
        for set_instance in self.ip_sets:
            if set_instance.purpose == purpose:
                return set_instance
        # A bare next() would leak StopIteration here; inside the async
        # callers that surfaces as an opaque
        # "RuntimeError: coroutine raised StopIteration".
        raise ValueError(f"no sync ipset is defined for purpose {purpose!r}")

    @staticmethod
    def _purpose_from_ipset_name(ipset_name: str) -> Purpose:
        """Return the first Purpose whose ".<purpose>." occurs in the name.

        Raises:
            ValueError: *ipset_name* contains no known purpose.
        """
        for purpose in Purpose:
            if f".{purpose}." in ipset_name:
                return purpose
        # explicit error instead of a leaked StopIteration (see above)
        raise ValueError(f"no known purpose in ipset name {ipset_name!r}")

    def get_rules(self, ip_version: IPVersion, **kwargs) -> Iterable[dict]:
        """Return the firewall rules of all enabled sync ipsets."""
        ruleset = []
        for set_ in self.get_enabled_ipset_instances():
            ruleset.extend(set_.create_rules(ip_version))
        return ruleset

    async def restore(self, ip_version: IPVersion) -> None:
        """Restore system ipsets from db."""
        # Define what actual [system] ipsets are created here instead
        # of in IPSetSync{White,Drop,Splashscreen,Captcha}, to get
        # greater flexibility in how the ipsets flushed/created/filled
        # e.g., to control the impact on memory consumption
        existing_ipsets = frozenset(await libipset.list_set())
        flush_sync_cmds = self._gen_flush_sync_cmds(
            ip_version, existing_ipsets
        )
        restore_cmds = (
            # commands to create the sync ipsets that do not exist yet
            self._gen_ipset_create_cmds(ip_version, exclude=existing_ipsets),
            # commands to flush the sync ipsets that already existed
            flush_sync_cmds,
            # commands to fill sync ipsets
            await self._fill_cmds(ip_version),
        )
        await libipset.restore(itertools.chain(*restore_cmds))

    def _gen_flush_sync_cmds(self, ip_version, existing_ipsets):
        """Lazily yield flush commands for the sync ipsets that exist."""
        # commands to flush sync [list:set] ipsets
        # if they exist regardless of config settings
        return (
            IPSetCmdBuilder.get_flush_cmd(ipset_name)
            for ipset_name in self.get_all_ipsets(ip_version)
            if ipset_name in existing_ipsets
        )

    def gen_ipset_flush_ops(self, ip_version, existing_ipsets):
        """Public alias of _gen_flush_sync_cmds()."""
        return self._gen_flush_sync_cmds(ip_version, existing_ipsets)

    def gen_ipset_create_ops(
        self, ip_version: IPVersion
    ) -> List[IPSetRestoreCmd]:
        """`ipset`'s commands to create remote iplist ipsets."""
        return self._gen_ipset_create_cmds(ip_version)

    def gen_ipset_destroy_ops(
        self, ip_version: IPVersion, existing: Set[str]
    ) -> Dict[str, IPSetRestoreCmd]:
        """Map each existing sync ipset name to its `ipset` destroy command.

        Only the sync [list:set] ipsets managed by this class are covered,
        and only those whose names are present in *existing*.
        """
        destroy_sync_cmds = {
            ipset_name: IPSetCmdBuilder.get_destroy_cmd(ipset_name)
            for ipset_name in self.get_all_ipsets(ip_version)
            if ipset_name in existing
        }
        return destroy_sync_cmds

    def _gen_ipset_create_cmds(
        self, ip_version: IPVersion, *, exclude: AbstractSet = frozenset()
    ) -> List[IPSetRestoreCmd]:
        """Return the list of commands creating the sync ipsets.

        Ipsets whose names are mentioned in the *exclude* set are skipped.
        """
        result = []
        for ipset in self.get_all_ipset_instances(ip_version):
            if ipset.gen_ipset_name_for_ip_version(ip_version) not in exclude:
                result.extend(ipset.gen_ipset_create_ops(ip_version))
        return result

    async def get_ipsets_count(
        self, ip_version: IPVersion
    ) -> List[IPSetCount]:
        """Expected vs. actual ipset member counts for all ipsets."""
        # Define the method here instead of IPSetSync's subclasses, to
        # avoid "one class--one ipset" restriction (to be able to
        # include "iplist_id" ipsets easily if necessary)
        return [
            IPSetCount(
                name=ipset_name,
                # expected number of members in the ipset
                # (db is the source of truth)
                db_count=(await self._get_db_count(ipset_name, ip_version)),
                # actual ipset member count as reported by ipset command
                ipset_count=(await libipset.get_ipset_count(ipset_name)),
            )
            for ipset_name in self.get_all_ipsets(ip_version)
        ]

    async def _get_db_count(
        self, ipset_name: str, ip_version: IPVersion
    ) -> int:
        """Return the db count for the purpose encoded in *ipset_name*."""
        purpose: Purpose = self._purpose_from_ipset_name(ipset_name)
        return IPListPurpose.fetch_count(ip_version, purpose)

    async def _fill_cmds(
        self, ip_version: IPVersion
    ) -> Iterable[IPSetRestoreCmd]:
        """Generate `ipset restore` commands to fill sync ipsets."""
        result = []
        for ipset in self.get_all_ipset_instances(ip_version):
            result.extend(await ipset.gen_ipset_restore_ops(ip_version))
        return result

    async def add_id_iplist(self, purpose, iplist_id, ip_version: IPVersion):
        """Add existing .id ipset to .sync list:set ipset"""
        set_name = self.get_ipset_instances_by_purpose(
            purpose
        ).gen_ipset_name_for_ip_version(ip_version)
        iplist_id_name = SingleIPSetSyncIPListRecord(
            iplist_id
        ).gen_ipset_name_for_ip_version(ip_version)
        await libipset.restore(
            (IPSetCmdBuilder.get_add_cmd(set_name, iplist_id_name),)
        )

    async def delete_id_iplist(
        self, purpose, iplist_id, ip_version: IPVersion
    ):
        """Remove .id ipset from .sync list:set ipset

        (without deleting .id ipset itself)
        """
        ipset_name = self.get_ipset_instances_by_purpose(
            purpose
        ).gen_ipset_name_for_ip_version(ip_version)
        iplist_id_name = SingleIPSetSyncIPListRecord(
            iplist_id
        ).gen_ipset_name_for_ip_version(ip_version)
        await libipset.delete_item(ipset_name, iplist_id_name)
class SingleIPSetSyncIPListRecord(IPSetAtomicRestoreBase):
    """One ".id" ipset holding the ips of a single remote iplist."""

    # template of the default ipset name
    _NAME = "{prefix}.{ip_version}.{iplist_id}.id"

    def __init__(self, iplist_id: IPListID):
        super().__init__(iplist_id)
        self.iplist_id = iplist_id

    def gen_ipset_name_for_ip_version(self, ip_version: IPVersion) -> str:
        """Return the ipset name: the custom one if set, else the default."""
        custom_name = self.custom_ipset_name
        if custom_name:
            return custom_name
        return self._NAME.format(
            prefix=IP_SET_PREFIX,
            ip_version=ip_version,
            iplist_id=self.iplist_id,
        )

    def gen_ipset_create_ops(self, ip_version: IPVersion) -> List[str]:
        """Return the command creating this ipset."""
        set_name = self.gen_ipset_name_for_ip_version(ip_version)
        create_cmd = IPSetCmdBuilder.get_create_cmd(
            set_name, get_ipset_family(ip_version), maxelem=2000_000
        )
        return [create_cmd]

    def gen_ipset_destroy_ops(self, ip_version: IPVersion) -> List[str]:
        """Return the command destroying this ipset."""
        set_name = self.gen_ipset_name_for_ip_version(ip_version)
        return [IPSetCmdBuilder.get_destroy_cmd(set_name)]

    def gen_ipset_flush_ops(self, ip_version: IPVersion) -> List[str]:
        """Return the command flushing this ipset."""
        set_name = self.gen_ipset_name_for_ip_version(ip_version)
        return [IPSetCmdBuilder.get_flush_cmd(set_name)]

    async def gen_ipset_restore_ops(self, ip_version: IPVersion) -> List[str]:
        """Return one add command per ip stored for this iplist."""
        add_cmds = []
        for ip in IPListRecord.fetch_ips(ip_version, self.iplist_id):
            add_cmds.append(
                IPSetCmdBuilder.get_add_cmd(
                    self.gen_ipset_name_for_ip_version(ip_version),
                    str(ip),
                )
            )
        return add_cmds
class IPSetSyncIPListRecords(IPSetCollectionResetMixin):
    """Namespace for ipsets populated by ips from IPListRecord table.

    The table stores ips for the remote iplists.
    """

    # `maxelem` passed when creating an iplist_id (".id") ipset
    _ID_IPSET_MAXELEM = 2000_000

    def get_all_ipsets(self, ip_version: IPVersion) -> FrozenSet[str]:
        """Return required [system] iplist_id ipset names.

        It does not check whether the ipsets are actually present on
        the system.
        """
        # all_ipsets() already returns a frozenset
        return self.all_ipsets(ip_version)

    @staticmethod
    def all_ipsets(
        ip_version: IPVersion, *, purposes: Iterable[Purpose] = Purpose
    ) -> FrozenSet[str]:
        """Return the names of all iplist_id ipsets known to the db.

        Only iplists serving one of *purposes* are included.
        """
        return frozenset(
            IPSetSyncIPListRecords._name_from_id(iplist_id, ip_version)
            for iplist_id in IPListPurpose.fetch_iplist_ids(
                ip_version, purposes
            )
        )

    def get_ipset_instances_by_purpose(
        self, ip_version: IPVersion, purpose: Iterable
    ):
        """Return an ipset object per iplist id fetched for *purpose*.

        *purpose* is an iterable of purposes, despite the singular name.
        """
        return [
            SingleIPSetSyncIPListRecord(iplist_id)
            for iplist_id in IPListPurpose.fetch_iplist_ids(
                ip_version, purpose
            )
        ]

    def get_all_ipset_instances(
        self, ip_version: IPVersion
    ) -> List[IPSetAtomicRestoreBase]:
        """Return ipset objects for the iplists of every purpose."""
        return self.get_ipset_instances_by_purpose(ip_version, Purpose)

    @staticmethod
    def _name_from_id(iplist_id: IPListID, ip_version: IPVersion) -> str:
        """Return the ipset name for *iplist_id* and *ip_version*."""
        return SingleIPSetSyncIPListRecord(
            iplist_id
        ).gen_ipset_name_for_ip_version(ip_version)

    @staticmethod
    def match_ipset_name(ipset_name: str, ip_version: IPVersion) -> bool:
        """Whether *ipset_name* looks like an iplist_id ipset name."""
        # Render a name with a placeholder id, escape it for use as a
        # regex, then let the placeholder match any run of digits.
        id_ph = "IPListIDplaceholder"  # should not contain any re chars
        name_ph = IPSetSyncIPListRecords._name_from_id(id_ph, ip_version)  # type: ignore # noqa: E501
        pattern = re.escape(name_ph).replace(id_ph, r"\d+")
        # return a real bool, as annotated, rather than a Match object
        return re.fullmatch(pattern, ipset_name) is not None

    @staticmethod
    def create_cmds(
        ip_version: IPVersion, *, exclude: AbstractSet[str] = frozenset()
    ) -> Iterator[IPSetRestoreCmd]:
        """Yield `ipset restore` commands to create ipsets for remote ips.

        Don't create ipsets with names from the *exclude* set.
        """
        return (
            IPSetCmdBuilder.get_create_cmd(
                ipset_name,
                get_ipset_family(ip_version),
                maxelem=IPSetSyncIPListRecords._ID_IPSET_MAXELEM,
            )
            for ipset_name in IPSetSyncIPListRecords.all_ipsets(ip_version)
            if ipset_name not in exclude
        )

    @staticmethod
    def fill_cmds(
        ip_version: IPVersion,
    ) -> Iterator[IPSetRestoreCmd]:
        """Yield `ipset restore` commands to populate ipsets for remote ips."""
        return (
            IPSetCmdBuilder.get_add_cmd(
                IPSetSyncIPListRecords._name_from_id(iplist_id, ip_version),
                str(ip),
            )
            for iplist_id in IPListPurpose.fetch_iplist_ids(
                ip_version, Purpose
            )
            for ip in IPListRecord.fetch_ips(ip_version, iplist_id)
        )

    @staticmethod
    async def get_ipsets_count(ip_version: IPVersion) -> List[IPSetCount]:
        """Expected vs. actual ipset member counts for .id ipsets."""
        return [
            IPSetCount(
                name=(
                    ipset_name := IPSetSyncIPListRecords._name_from_id(
                        iplist_id, ip_version
                    )
                ),
                # expected number of members in the ipset
                # (db is the source of truth)
                db_count=IPListRecord.fetch_ips_count(ip_version, iplist_id),
                # actual ipset member count as reported by ipset command
                ipset_count=(await libipset.get_ipset_count(ipset_name)),
            )
            for iplist_id in IPListPurpose.fetch_iplist_ids(
                ip_version, Purpose
            )
        ]

    def _lines_to_restore(self, iplist_id, ips, *, cmd_creator):
        """Yield cmd_creator(set_name, ip) for each of *ips*.

        The target ipset is picked per ip from its address family, so
        *ips* may mix IPv4 and IPv6 values.
        """
        for ip in ips:
            version = ipaddress.ip_network(ip).version
            # NOTE(review): presumably "ipv4"/"ipv6" equal IP.V4/IP.V6 --
            # confirm against defence360agent.utils.validate.IP
            set_name = self._name_from_id(iplist_id, f"ipv{version}")
            yield cmd_creator(set_name, ip)

    async def add_ips(self, iplist_id, ips):
        """Add *ips* to the ipset(s) of *iplist_id*."""
        await libipset.restore(
            self._lines_to_restore(
                iplist_id, ips, cmd_creator=IPSetCmdBuilder.get_add_cmd
            )
        )

    async def delete_ips(self, iplist_id, ips):
        """Delete *ips* from the ipset(s) of *iplist_id*."""
        await libipset.restore(
            self._lines_to_restore(
                iplist_id, ips, cmd_creator=IPSetCmdBuilder.get_delete_cmd
            )
        )

    async def restore(self, ip_version: IPVersion) -> None:
        """Restore system ipsets from db."""
        existing_ipsets = frozenset(await libipset.list_set())
        flush_id_cmds = self._gen_flush_id_cmds(ip_version, existing_ipsets)
        restore_cmds = (
            # commands to create iplist_id ipsets if necessary
            self.create_cmds(ip_version, exclude=existing_ipsets),
            # commands to flush the iplist_id ipsets that already existed
            flush_id_cmds,
            # commands to fill iplist_id ipsets
            self.fill_cmds(ip_version),
        )
        await libipset.restore(itertools.chain(*restore_cmds))

    def _gen_flush_id_cmds(self, ip_version, existing_ipsets):
        """Lazily yield flush commands for existing iplist_id ipsets."""
        # commands to flush iplist_id ipsets (with ips)
        return (
            IPSetCmdBuilder.get_flush_cmd(ipset_name)
            for ipset_name in existing_ipsets
            if self.match_ipset_name(ipset_name, ip_version)
        )

    def gen_ipset_flush_ops(self, ip_version, existing_ipsets):
        """Public alias of _gen_flush_id_cmds()."""
        return self._gen_flush_id_cmds(ip_version, existing_ipsets)

    @staticmethod
    async def create(iplist_id, version):
        """create_id_ipset (if it doesn't exist)

        *version* is the numeric ip version: 4 selects IP.V4, any other
        value selects IP.V6.
        """
        ip_version = IP.V4 if version == 4 else IP.V6
        command = IPSetCmdBuilder.get_create_cmd(
            IPSetSyncIPListRecords._name_from_id(iplist_id, ip_version),
            get_ipset_family(ip_version),
            maxelem=IPSetSyncIPListRecords._ID_IPSET_MAXELEM,
        )
        await libipset.restore((command,))

    async def delete(self, iplist_id, ip_version: IPVersion):
        """delete_id_ipset"""
        set_name = self._name_from_id(iplist_id, ip_version)
        await libipset.delete_set(set_name)

    async def flush_ips(self, iplist_id, ip_version: IPVersion):
        """Flush the ipset of *iplist_id*."""
        set_name = self._name_from_id(iplist_id, ip_version)
        await libipset.flush_set(set_name)

    def gen_ipset_destroy_ops(
        self, ip_version: IPVersion, existing: Set[str]
    ) -> Dict[str, IPSetRestoreCmd]:
        """Map each existing iplist_id ipset name to its destroy command.

        Only names from *existing* that look like iplist_id ipset names
        for *ip_version* are included.
        """
        destroy_id_cmds = {
            ipset_name: IPSetCmdBuilder.get_destroy_cmd(ipset_name)
            for ipset_name in existing
            if self.match_ipset_name(ipset_name, ip_version)
        }
        return destroy_id_cmds

    def get_rules(
        self, ip_version: IPVersion, **kwargs: Args
    ) -> Iterable[dict]:
        """Return remote iplist firewall rules (always empty here)."""
        # Define iptables rules in
        # IPSetSync{White,Drop,Splashscreen,Captcha} instead of here,
        # to allow both old/new ipsets be active at the same time: old
        # graylist/graysplashlist ipsets may contain local/non-server,
        # therefore new "sync" ipset that contain server-only values
        # can't replace them
        return ()  # delegate actual rules to the above IPSetSync* classes
def _captcha_rules_enabled() -> bool:
    """Tell whether webshield is enabled in config and expects traffic."""
    webshield_enabled = WebshieldConfig.ENABLE
    if not webshield_enabled:
        return webshield_enabled
    return webshield.expects_traffic()
def _splashscreen_rules_enabled() -> bool:
    """Tell whether captcha rules are on and the splash screen is configured."""
    captcha_enabled = _captcha_rules_enabled()
    if not captcha_enabled:
        return captcha_enabled
    return WebshieldConfig.SPLASH_SCREEN