Mini Shell
"""Core module for managing firewall rules and ipsets."""
import json
from pathlib import Path
import logging
from typing import Iterable, List, Optional, Set, Tuple
from defence360agent.internals.global_scope import g
from im360.contracts.config import (
NetworkInterface,
UnifiedAccessLogger,
DOS,
EnhancedDOS,
)
from im360.internals.core.ipset.port_deny import (
InputPortBlockingDenyModeIPSet,
OutputPortBlockingDenyModeIPSet,
)
from defence360agent.utils.validate import IPVersion
from im360.internals.core.firewall import Iptables
from . import ip_versions
from .firewall import (
FirewallRules,
RuleDef,
firewall_logging_enabled,
is_nat_available,
)
from .ipset import IP_SET_PREFIX, libipset
from .ipset.country import IPSetCountry
from .ipset.ip import IPSet
from .ipset.libipset import IPSetCmdBuilder, IPSetRestoreCmd
from .ipset.port import IPSetIgnoredByPort, IPSetPort
from .ipset.redirect import (
IPSetNoRedirectPort,
IPSetWebshieldPort,
)
from .ipset.sync import IPSetSyncIPListPurpose, IPSetSyncIPListRecords
logger = logging.getLogger(name=__name__)
# Per-IP-version JSON dump with the names of ipsets that could not be
# destroyed; the "{ip_version}" placeholder is filled via str.format().
FAILED_IPSETS_FILE = (
    "/var/imunify360/" "failed_ipsets_{ip_version}.json"
)
class RuleSet:
    """Manage Imunify360 iptables rules and ipsets.

    Generates the firewall commands that create, check and destroy the
    Imunify360 chains and rules, and drives the lifecycle (create, fill,
    flush, destroy, recreate) of the ipsets owned by ``self.entities``.
    """

    # Sub-chains created, flushed and deleted in the FILTER table.
    _CHAINS = [
        FirewallRules.COUNTRY_WHITELIST_CHAIN,
        FirewallRules.COUNTRY_BLACKLIST_CHAIN,
        FirewallRules.BP_INPUT_CHAIN,
        FirewallRules.LOG_BLACKLIST_CHAIN,
        FirewallRules.LOG_GRAYLIST_CHAIN,
        FirewallRules.LOG_BLACKLISTED_COUNTRY_CHAIN,
        FirewallRules.WEBSHIELD_PORTS_INPUT_CHAIN,
        FirewallRules.LOG_BLOCK_PORT_CHAIN,
    ]

    # Since DB and ipset are updated at different times,
    # check relative value instead of compare absolute values.
    # Use a large enough relative number to avoid false positives,
    # 20% difference looks reasonable for this.
    # NOTE(review): this constant is not referenced inside this class and
    # get_outdated_ipsets() compares counts exactly; confirm whether it is
    # consumed elsewhere or whether that method should apply it.
    _IPSET_COUNT_TO_RECREATE_THRESHOLD = 0.2

    # raw/PREROUTING conntrack rule. Shared by the create, check and destroy
    # paths so the rule being deleted always matches the inserted one.
    # fmt: off
    _CONNTRACK_RULE = (
        "-m", "comment",
        "--comment", '"Connection tracking for Imunify360."',
        "-j", "CT",
    )
    # fmt: on

    def __init__(self):
        # Ipset-owning entities. Commands are generated by iterating this
        # tuple, so its order is the order in which ipsets are created.
        self.entities = (
            InputPortBlockingDenyModeIPSet(),
            OutputPortBlockingDenyModeIPSet(),
            IPSetPort(),
            IPSet(),
            # Order is important here,
            # Ensure IPSetSyncIPListRecords is created before IPSetSyncIPListPurpose
            IPSetSyncIPListRecords(),
            IPSetSyncIPListPurpose(),
            IPSetCountry(),
            IPSetIgnoredByPort(),
            IPSetNoRedirectPort(),
            IPSetWebshieldPort(),
        )

    @staticmethod
    def targets(ip_version: IPVersion) -> List[Tuple]:
        """
        Returns tables & chains that Imunify360 will use in firewall management

        FILTER/INPUT always; PREROUTING in the NAT table when NAT is
        available for *ip_version*, otherwise in the MANGLE table.

        :param ip_version: IPv4 or IPv6
        :return: List[Tuple]: (table, chain) pairs
        """
        prerouting_table = (
            FirewallRules.NAT
            if is_nat_available(ip_version)
            else FirewallRules.MANGLE
        )
        return [
            (FirewallRules.FILTER, "INPUT"),
            (prerouting_table, "PREROUTING"),
        ]

    @staticmethod
    def _apply_ignored_interfaces(action, interface_conf, *args, **kwargs):
        """Yield ``action(...)`` results for every skipped interface.

        For each device listed under ``NetworkInterface.DEVICE_SKIP`` an
        ACCEPT rule matching that interface is passed to *action* with
        ``chain=IMUNIFY_INPUT_CHAIN`` and ``priority=0``.

        :param Callable action: action to perform with interface
            (e.g. ``firewall.insert_rule`` or ``firewall.has_rule``)
        :param interface_conf: interface configuration
        :param args: extra positional arguments forwarded to *action*
        :param kwargs: extra keyword arguments forwarded to *action*
        """
        for interface in interface_conf[NetworkInterface.DEVICE_SKIP]:
            accept_rule = FirewallRules.compose_rule(
                FirewallRules.interface(interface),
                action=FirewallRules.compose_action(FirewallRules.ACCEPT),
            )
            # Unpacked positional args always follow the rule, regardless
            # of where they appear in the call; keep them before keywords
            # so the call reads the way Python binds it.
            yield action(
                accept_rule,
                *args,
                chain=FirewallRules.IMUNIFY_INPUT_CHAIN,
                priority=0,  # max priority for firewalld
                **kwargs,
            )

    @staticmethod
    def _compose_rule(ip_version: IPVersion, interface_conf: dict) -> RuleDef:
        """Compose the jump-to-IMUNIFY_INPUT_CHAIN rule.

        When an interface is configured for *ip_version* the jump is
        restricted to that interface; otherwise the bare jump action is
        returned.
        """
        target_interface = interface_conf[ip_version]
        action = FirewallRules.compose_action(
            FirewallRules.IMUNIFY_INPUT_CHAIN
        )
        if not target_interface:
            return action
        return FirewallRules.compose_rule(
            FirewallRules.interface(target_interface), action=action
        )

    async def ipset_create_commands(self, ip_version: IPVersion) -> List[str]:
        """Return the ipset create operations of all entities, in order."""
        cmds: List[str] = []
        for entity in self.entities:
            cmds.extend(entity.gen_ipset_create_ops(ip_version))
        return cmds

    async def ipset_flush_commands(
        self, ip_version: IPVersion, existing: Optional[Set[str]] = None
    ) -> Iterable[IPSetRestoreCmd]:
        """Generate ipset restore commands to flush *existing* ipsets.

        Only entities that implement ``gen_ipset_flush_ops`` contribute.
        When *existing* is None, the currently existing Imunify360 ipsets
        for *ip_version* are used.
        """
        if existing is None:
            existing = await self.existing_ipsets(ip_version)
        cmds = []
        for entity in self.entities:
            # Flushing is optional per entity.
            if hasattr(entity, "gen_ipset_flush_ops"):
                cmds += entity.gen_ipset_flush_ops(ip_version, existing)
        return cmds

    async def ipset_destroy_commands(
        self, ip_version: IPVersion, existing: Optional[Set[str]] = None
    ) -> Iterable[IPSetRestoreCmd]:
        """Generate ipset restore commands to destroy *existing* ipsets.

        Entity-specific destroy commands take precedence; any remaining
        name in *existing* gets a generic destroy command. When *existing*
        is None, the currently existing Imunify360 ipsets are used.
        """
        if existing is None:
            existing = await self.existing_ipsets(ip_version)
        # ipset name -> destroy command; entity-specific commands first
        cmds: dict[str, IPSetRestoreCmd] = {}
        for entity in self.entities:
            cmds.update(entity.gen_ipset_destroy_ops(ip_version, existing))
        # generic destroy
        for ipset_name in existing:
            if ipset_name not in cmds:
                # ipset is not special, remove using a generic destroy command
                cmds[ipset_name] = IPSetCmdBuilder.get_destroy_cmd(ipset_name)
        return cmds.values()

    async def create_commands(
        self, firewall, interface_conf: dict, ip_version: IPVersion
    ) -> list:
        """Return a list of firewall commands to create all required rules.

        In order: the input chain with its jump rule (and skipped-interface
        ACCEPT rules) per target table, the FILTER sub-chains, the log/block
        rules, the output chain with its jump rule, the output sub-chain,
        the ipset rules and, if DOS protection is enabled, the conntrack
        rule.
        """
        actions = []
        # input chains: main chain and rule
        for table, chain in self.targets(ip_version):
            actions.extend(
                [
                    firewall.create_chain(
                        table=table, chain=FirewallRules.IMUNIFY_INPUT_CHAIN
                    ),
                    firewall.insert_rule(
                        self._compose_rule(ip_version, interface_conf),
                        table=table,
                        chain=chain,
                    ),
                    *self._apply_ignored_interfaces(
                        firewall.insert_rule, interface_conf, table=table
                    ),
                ]
            )
        # input subchains
        actions.extend(
            [
                firewall.create_chain(table=FirewallRules.FILTER, chain=chain)
                for chain in self._CHAINS
            ]
        )
        # log block rules
        actions.extend(self._log_block_rules(firewall.append_rule, ip_version))
        # output chains: main chain and rule
        actions.extend(
            [
                firewall.create_chain(
                    table=FirewallRules.FILTER,
                    chain=FirewallRules.IMUNIFY_OUTPUT_CHAIN,
                ),
                firewall.insert_rule(
                    FirewallRules.compose_action(
                        FirewallRules.IMUNIFY_OUTPUT_CHAIN
                    ),
                    chain="OUTPUT",
                ),
            ]
        )
        # output subchain
        actions.append(
            firewall.create_chain(
                table=FirewallRules.FILTER,
                chain=FirewallRules.BP_OUTPUT_CHAIN,
            )
        )
        # ipsets rules (can be in NAT or FILTER table)
        actions.extend(
            [
                firewall.append_rule(**rule)
                for rule in await self._collect_ipset_rules(ip_version)
            ]
        )
        if DOS.ENABLED or EnhancedDOS.ENABLED:
            # Add connection tracking rule.
            actions.append(
                firewall.insert_rule(
                    self._CONNTRACK_RULE, table="raw", chain="PREROUTING"
                )
            )
        return actions

    def destroy_commands(
        self, firewall, interface_conf: dict, ip_version: IPVersion
    ) -> Iterable[list]:
        """Returns an iterable over list of commands to destroy firewall rules.

        Each list should be executed as a separate firewall commit
        operation. The conntrack rule deletion is always yielded, whether
        or not DOS protection is currently enabled."""
        # input chains
        for table, chain in self.targets(ip_version):
            # delete main rule
            yield [
                firewall.delete_rule(
                    self._compose_rule(ip_version, interface_conf),
                    table=table,
                    chain=chain,
                )
            ]
            # flush and delete main chain (drops skipped-interface rules too)
            yield [
                firewall.flush_chain(
                    FirewallRules.IMUNIFY_INPUT_CHAIN, table=table
                ),
                firewall.delete_chain(
                    FirewallRules.IMUNIFY_INPUT_CHAIN, table=table
                ),
            ]
        for chain in self._CHAINS:
            yield [
                firewall.flush_chain(chain, table=FirewallRules.FILTER),
                firewall.delete_chain(chain, table=FirewallRules.FILTER),
            ]
        # output chains: delete main rule
        yield [
            firewall.delete_rule(
                FirewallRules.compose_action(
                    FirewallRules.IMUNIFY_OUTPUT_CHAIN
                ),
                chain="OUTPUT",
            )
        ]
        # flush and delete main chain
        yield [
            firewall.flush_chain(FirewallRules.IMUNIFY_OUTPUT_CHAIN),
            firewall.delete_chain(FirewallRules.IMUNIFY_OUTPUT_CHAIN),
        ]
        yield [
            firewall.flush_chain(FirewallRules.BP_OUTPUT_CHAIN),
            firewall.delete_chain(FirewallRules.BP_OUTPUT_CHAIN),
        ]
        # Delete connection tracking rule.
        yield [
            firewall.delete_rule(
                self._CONNTRACK_RULE, table="raw", chain="PREROUTING"
            )
        ]

    def required_ipsets(self, ip_version: IPVersion) -> Set[str]:
        """Return the names of all ipsets the entities need."""
        names: Set[str] = set()
        for entity in self.entities:
            names.update(entity.get_all_ipsets(ip_version))
        return names

    async def check_commands(
        self, firewall: Iptables, interface_conf: dict, ip_version: IPVersion
    ) -> list:
        """Returns a list of firewall commands to check for firewall rules.

        Mirrors the rules inserted/appended by ``create_commands``.
        """
        actions = []
        for table, chain in self.targets(ip_version):
            actions.extend(
                [
                    firewall.has_rule(
                        self._compose_rule(ip_version, interface_conf),
                        table=table,
                        chain=chain,
                    ),
                    *self._apply_ignored_interfaces(
                        firewall.has_rule, interface_conf, table=table
                    ),
                ]
            )
        actions.extend(self._log_block_rules(firewall.has_rule, ip_version))
        actions.append(
            firewall.has_rule(
                FirewallRules.compose_action(
                    FirewallRules.IMUNIFY_OUTPUT_CHAIN
                ),
                table=FirewallRules.FILTER,
                chain="OUTPUT",
            )
        )
        actions.extend(
            [
                firewall.has_rule(**rule)
                for rule in await self._collect_ipset_rules(ip_version)
            ]
        )
        if DOS.ENABLED or EnhancedDOS.ENABLED:
            actions.append(
                firewall.has_rule(
                    self._CONNTRACK_RULE, table="raw", chain="PREROUTING"
                )
            )
        return actions

    def _log_block_rules(self, predicate, ip_version: IPVersion):
        """Apply *predicate* to every rule of the log/block FILTER chains.

        :param predicate: callable invoked as
            ``predicate(rule, table=FILTER, chain=chain)``
        :return: list of *predicate* results
        """
        rules = []
        for chain, prefix, action in (
            (
                FirewallRules.LOG_BLACKLIST_CHAIN,
                UnifiedAccessLogger.BLACKLIST,
                FirewallRules.compose_action(FirewallRules.DROP),
            ),
            (
                FirewallRules.LOG_GRAYLIST_CHAIN,
                UnifiedAccessLogger.GRAYLIST,
                FirewallRules.compose_action(FirewallRules.DROP),
            ),
            (
                FirewallRules.LOG_BLACKLISTED_COUNTRY_CHAIN,
                UnifiedAccessLogger.BLACKLIST_COUNTRY,
                FirewallRules.compose_action(FirewallRules.DROP),
            ),
            (
                FirewallRules.LOG_BLOCK_PORT_CHAIN,
                UnifiedAccessLogger.BLOCKED_BY_PORT,
                FirewallRules.compose_action(FirewallRules.REJECT),
            ),
        ):
            # At the moment, stateful packets processing is enabled
            # for blacklisted countries only.
            stateful = chain == FirewallRules.LOG_BLACKLISTED_COUNTRY_CHAIN
            rules.extend(
                predicate(rule, table=FirewallRules.FILTER, chain=chain)
                for rule in self._log_drop_rules(
                    ip_version, prefix, action, stateful
                )
            )
        return rules

    async def _collect_ipset_rules(self, ip_version: IPVersion) -> List[dict]:
        """Return all entity ipset rules sorted by (chain, priority)."""
        rules: List[dict] = []
        for entity in self.entities:
            rules.extend(entity.get_rules(ip_version))
        rules.sort(key=lambda r: (r["chain"], r["priority"]))
        return rules

    async def fill_ipsets(
        self, ip_version: IPVersion, missing: Set[str]
    ) -> None:
        """Create and fill the ipsets whose names are listed in *missing*.

        Ipsets not in *missing* are left untouched. All create and restore
        operations are sent in a single ``libipset.restore`` call.
        """
        create_and_restore_cmds = []
        for entity in self.entities:
            for ip_set in entity.get_all_ipset_instances(ip_version):
                if ip_set.gen_ipset_name_for_ip_version(ip_version) in missing:
                    create_and_restore_cmds.extend(
                        ip_set.gen_ipset_create_ops(ip_version)
                    )
                    create_and_restore_cmds.extend(
                        await ip_set.gen_ipset_restore_ops(ip_version)
                    )
        await libipset.restore(create_and_restore_cmds)
        logger.info("IP sets content restored from database")

    @staticmethod
    async def existing_ipsets(ip_version: IPVersion) -> Set[str]:
        """Return names of existing ipsets starting with the Imunify360
        prefix for *ip_version* (``"<IP_SET_PREFIX>.<ip_version>"``)."""
        prefix = ".".join([IP_SET_PREFIX, ip_version])
        all_sets = await libipset.list_set()
        return {name for name in all_sets if name.startswith(prefix)}

    async def _flush_ipsets(self, to_flush: set[str], ip_version: IPVersion):
        """Flush *to_flush*; a missing ipset is logged, not raised."""
        logger.info("Flushing ipsets: %s", to_flush)
        try:
            await libipset.restore(
                await self.ipset_flush_commands(ip_version, to_flush)
            )
        except libipset.IPSetNotFoundError:
            logger.warning(
                "Failed to flush ipsets: %s",
                ", ".join(to_flush),
            )

    def has_ipset_to_destroy(
        self, ip_version: IPVersion, existing: set[str] | None
    ) -> bool:
        """Return True if *existing* has a name not recorded as previously
        failed to destroy; False when *existing* is None."""
        if existing is None:
            return False
        return bool(existing - self._get_prev_failed(ip_version))

    def ipsets_to_refill(
        self, ip_version: IPVersion, existing: set[str], required: set[str]
    ) -> set[str]:
        """Return ipsets that need to be refilled.

        These are ipsets that still exist, are required, and were recorded
        as previously failed to destroy.
        """
        prev_failed = self._get_prev_failed(ip_version)
        return existing & prev_failed & required

    async def destroy_ipsets(
        self,
        ip_version: IPVersion,
        existing: set[str] | None = None,
        force: bool = False,
    ) -> None:
        """Destroys ipsets with given names.

        The ipsets are flushed, destroyed as a group (with retries), and any
        leftovers are destroyed one by one. Names that still fail, together
        with previously failed names that still have references, are saved
        via ``_save_failed_ipsets``.

        Args:
            ip_version: IP version to destroy ipsets for
            existing: Set of ipsets to destroy. If None, all existing ipsets
                will be destroyed. The given set is not modified.
            force: If True, ignore previously failed ipsets and try to
                destroy them again
        """
        logger.info(
            "Destroying ipsets for %s existing: %s force: %s",
            ip_version,
            existing,
            force,
        )
        to_destroy = (
            existing.copy()
            if existing is not None
            else await self.existing_ipsets(ip_version)
        )
        await self._flush_ipsets(to_destroy, ip_version)
        prev_failed = set() if force else self._get_prev_failed(ip_version)
        # Previously failed ipsets without references may be retried now.
        prev_failed -= await self._sets_without_references(prev_failed)
        failed_ipsets = await self._destroy_ipsets_group(
            to_destroy, prev_failed, ip_version
        )
        failed_ipsets = await self._destroy_ipsets_one_by_one(failed_ipsets)
        if failed_ipsets:
            ipset_with_members = {
                ipset: await libipset.get_ipset_members(ipset)
                for ipset in failed_ipsets
            }
            references = {
                ipset: await libipset.get_ipset_references(ipset)
                for ipset in failed_ipsets
            }
            logger.error(
                "Failed to destroy ipsets: %s",
                ", ".join(
                    f"{ipset=}: {members=} refs: {references[ipset]}"
                    for ipset, members in ipset_with_members.items()
                ),
            )
        self._save_failed_ipsets(failed_ipsets | prev_failed, ip_version)

    def clean_previously_failed_ipsets(self, ip_version: IPVersion) -> None:
        """Clean previously failed ipsets from the file (no-op if absent)."""
        Path(FAILED_IPSETS_FILE.format(ip_version=ip_version)).unlink(
            missing_ok=True
        )

    async def _sets_without_references(self, ipsets: set[str]) -> set[str]:
        """Return a set of ipsets that have no references (safe to destroy)."""
        res = set()
        for ipset in ipsets:
            if await libipset.get_ipset_references(ipset) == 0:
                res.add(ipset)
        return res

    def _get_prev_failed(self, ip_version: IPVersion) -> set[str]:
        """Load ipset names recorded as failed to destroy by a previous run.

        Returns an empty set when the dump file does not exist, cannot be
        read, or does not hold a JSON array of hashable names.
        """
        path = FAILED_IPSETS_FILE.format(ip_version=ip_version)
        prev_failed: set[str] = set()
        try:
            with open(path, "r") as f:
                prev_failed = set(json.load(f))
        except FileNotFoundError:
            # Nothing failed previously (or the file was cleaned).
            pass
        except (ValueError, OSError, TypeError):
            # ValueError covers malformed JSON and undecodable bytes;
            # TypeError covers JSON that is not a list of names.
            logger.error("Failed to read or parse dump file: %s", path)
        if g.get("DEBUG"):
            logger.info("Previous failed ipsets: %s", prev_failed)
        return prev_failed

    async def _destroy_ipsets_group(
        self,
        to_destroy: set[str],
        prev_failed: set[str],
        ip_version: IPVersion,
    ) -> set[str]:
        """Destroy ipsets in one batch, retrying up to three times.

        Before each attempt, previously failed names and ipsets that no
        longer exist are dropped. *to_destroy* itself is not modified.

        :return: names that could not be destroyed (empty on success)
        """
        max_tries = 3
        remaining = set(to_destroy)
        for attempt in range(1, max_tries + 1):
            if not remaining:
                break
            remaining -= prev_failed
            remaining &= await self.existing_ipsets(ip_version)
            try:
                await libipset.restore(
                    await self.ipset_destroy_commands(ip_version, remaining)
                )
                return set()
            except (
                libipset.IPSetNotFoundError,
                libipset.IPSetCannotBeDestroyedError,
            ):
                logger.warning(
                    "Failed to destroy ipsets: %s, retrying: %s",
                    ", ".join(remaining),
                    attempt,
                )
        # return failed to destroy ipsets
        return remaining

    async def _destroy_ipsets_one_by_one(
        self, to_destroy: set[str]
    ) -> set[str]:
        """Destroy each ipset separately; return the ones that failed.

        A missing ipset counts as destroyed.
        """
        if g.get("DEBUG"):
            logger.info("Destroying ipsets: %s", to_destroy)
        failed_ipsets = set()
        for ipset_name in to_destroy:
            try:
                await libipset.restore(
                    [IPSetCmdBuilder.get_destroy_cmd(ipset_name)]
                )
            except libipset.IPSetNotFoundError:
                # If ipset doesn't exist, we can consider it destroyed
                pass
            except libipset.IPSetCannotBeDestroyedError as e:
                logger.warning(
                    "Failed to destroy ipset %s: %s", ipset_name, str(e)
                )
                failed_ipsets.add(ipset_name)
        return failed_ipsets

    def _save_failed_ipsets(
        self, failed_ipsets: set[str], ip_version: IPVersion
    ):
        """Persist *failed_ipsets* as a JSON list (best effort: errors are
        logged, not raised)."""
        path = FAILED_IPSETS_FILE.format(ip_version=ip_version)
        logger.info(
            "Saving failed ipsets: %s to file: %s",
            failed_ipsets,
            path,
        )
        try:
            with open(path, "w") as f:
                json.dump(list(failed_ipsets), f)
        except Exception as e:
            logger.error("Failed to save failed ipsets to file: %s", e)

    async def _recreate_ipsets(
        self, ip_version: IPVersion, existing: Optional[Set[str]] = None
    ):
        """Reset all ipsets, create them again and fill with IPs
        for given ip version."""
        for entity in self.entities:
            await entity.reset(ip_version, existing)

    async def recreate_ipsets(
        self,
        ip_version: Optional[IPVersion] = None,
        existing: Optional[Set[str]] = None,
    ):
        """Recreate existing ipsets (or given).

        If *ip_version* is None (or falsy), recreate ipsets for all enabled
        ip versions, passing the same *existing* to each.
        """
        versions = [ip_version] if ip_version else ip_versions.enabled()
        for version in versions:
            await self._recreate_ipsets(version, existing)

    @staticmethod
    def _log_drop_rules(ip_version: IPVersion, prefix, action, stateful: bool):
        """Build the rules of one log/block chain.

        Optionally an ESTABLISHED,RELATED ACCEPT rule (when *stateful*),
        then an NFLOG rule (when firewall logging is enabled), then
        *action*.
        """
        rules = []
        if stateful:
            rules.append(
                # fmt: off
                (
                    "-m", "conntrack",
                    "--ctstate", "ESTABLISHED,RELATED",
                    "-j", "ACCEPT",
                ),
                # fmt: on
            )
        if firewall_logging_enabled():
            rules.append(
                FirewallRules.compose_rule(
                    action=FirewallRules.nflog_action(
                        group=FirewallRules.nflog_group(ip_version),
                        prefix=prefix,
                    )
                )
            )
        rules.append(action)
        return rules

    async def get_outdated_ipsets(self, ip_version: IPVersion) -> list:
        """
        Return list of ipsets the contents of which do not match the database

        NOTE(review): counts are compared exactly here; the relative
        _IPSET_COUNT_TO_RECREATE_THRESHOLD is not applied — confirm intent.
        """
        outdated: list = []
        for entity in self.entities:
            all_ipsets = await entity.get_ipsets_count(ip_version)
            outdated.extend(
                ipset
                for ipset in all_ipsets
                if ipset.ipset_count != ipset.db_count
            )
        return outdated