Skip to content

Commit

Permalink
Add WIP support to add route
Browse files Browse the repository at this point in the history
  • Loading branch information
thomascellerier committed Dec 16, 2024
1 parent 95e2129 commit d714726
Show file tree
Hide file tree
Showing 8 changed files with 242 additions and 3 deletions.
46 changes: 45 additions & 1 deletion src/aiortnetlink/cli.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
import socket
import sys
from typing import Literal

Expand Down Expand Up @@ -96,6 +97,17 @@ def parse_args() -> argparse.Namespace:
action="store_true",
)

# route add
route_add_parser = route_subparsers.add_parser("add", aliases=["a"])
route_add_parser.add_argument("DESTINATION")
route_add_parser.add_argument("-4", "--ipv4", action="store_true")
route_add_parser.add_argument("-6", "--ipv6", action="store_true")
route_add_destination_group = route_add_parser.add_mutually_exclusive_group(
required=True
)
route_add_destination_group.add_argument("--via", "--gateway")
route_add_destination_group.add_argument("--dev", "--oif")

# rule
rule_parser = subparsers.add_parser("rule", aliases=["ru"])
rule_subparsers = rule_parser.add_subparsers(
Expand Down Expand Up @@ -329,7 +341,7 @@ async def run(args: argparse.Namespace) -> None:
proto_id_to_name = _rt_protos(args)
scope_id_to_name = _rt_scopes(args)

get_address = ipaddress.ip_address(args.ADDRESS)
get_address = ipaddress.ip_interface(args.ADDRESS)
async with NetlinkClient(**client_args) as nl:
link_index_to_name = await _link_index_to_name(args, nl)

Expand All @@ -343,6 +355,38 @@ async def run(args: argparse.Namespace) -> None:
)
)

case argparse.Namespace(object="route" | "ro" | "r", command="add" | "a"):
import ipaddress

async with NetlinkClient(**client_args) as nl:
if args.dev:
oif = (await nl.get_link(ifi_name=args.dev)).index
if oif is None:
raise NetlinkError(f"No such device {args.dev}")
else:
oif = None

if args.ipv4:
family = socket.AF_INET
elif args.ipv6:
family = socket.AF_INET6
else:
family = None

destination = args.DESTINATION
if "/" in destination:
destination = ipaddress.ip_network(destination)
else:
address = ipaddress.ip_address(destination)
destination = ipaddress.ip_network((address, address.max_prefixlen))

await nl.add_route(
destination=destination,
gateway=ipaddress.ip_address(args.via) if args.via else None,
oif=oif,
family=family,
)

case argparse.Namespace(object="rule" | "ru", command="show" | "s"):
if not args.numeric:
from aiortnetlink.rtfile import parse_rt_tables
Expand Down
26 changes: 25 additions & 1 deletion src/aiortnetlink/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,14 @@
)

if TYPE_CHECKING:
from ipaddress import IPv4Address, IPv4Interface, IPv6Address, IPv6Interface
from ipaddress import (
IPv4Address,
IPv4Interface,
IPv4Network,
IPv6Address,
IPv6Interface,
IPv6Network,
)
from types import TracebackType
from typing import AsyncIterator, Self

Expand Down Expand Up @@ -254,6 +261,23 @@ async def get_route(self, address: IPv4Address | IPv6Address) -> Route:
assert found_route is not None
return found_route

async def add_route(
self,
destination: IPv4Network | IPv6Network | None = None,
gateway: IPv4Address | IPv6Address | None = None,
oif: int | None = None,
family: int | None = None,
) -> None:
route_type_ = route_type()
request = route_type_.rtm_add(
destination=destination,
gateway=gateway,
oif=oif,
family=family,
)
async for _ in self._send_request(request):
pass

async def get_rules(self) -> AsyncIterator[Rule]:
rule_type_ = rule_type()
request = rule_type_.rtm_get()
Expand Down
20 changes: 20 additions & 0 deletions src/aiortnetlink/constants/rtprot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""
This file was generated by gen_constants.py
"""

from enum import IntEnum
from typing import Final

__all__ = ["RTProt"]


class RTProt(IntEnum):
UNSPEC: Final = 0
REDIRECT: Final = 1
KERNEL: Final = 2
BOOT: Final = 3
STATIC: Final = 4

@property
def constant_name(self) -> str:
return f"RTPROT_{self.name}"
20 changes: 20 additions & 0 deletions src/aiortnetlink/constants/rtscope.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""
This file was generated by gen_constants.py
"""

from enum import IntEnum
from typing import Final

__all__ = ["RTScope"]


class RTScope(IntEnum):
UNIVERSE: Final = 0
SITE: Final = 200
LINK: Final = 253
HOST: Final = 254
NOWHERE: Final = 255

@property
def constant_name(self) -> str:
return f"RT_SCOPE_{self.name}"
19 changes: 19 additions & 0 deletions src/aiortnetlink/constants/rttable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
"""
This file was generated by gen_constants.py
"""

from enum import IntEnum
from typing import Final

__all__ = ["RTTable"]


class RTTable(IntEnum):
UNSPEC: Final = 0
DEFAULT: Final = 253
MAIN: Final = 254
LOCAL: Final = 255

@property
def constant_name(self) -> str:
return f"RT_TABLE_{self.name}"
68 changes: 67 additions & 1 deletion src/aiortnetlink/route.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import socket
import struct
from dataclasses import dataclass
from ipaddress import IPv4Address, IPv6Address
from ipaddress import IPv4Address, IPv4Network, IPv6Address, IPv6Network
from typing import Callable, Literal, NamedTuple

from aiortnetlink.constants.icmpv6routerpref import ICMPv6RouterPref
Expand All @@ -12,6 +12,9 @@
from aiortnetlink.constants.rtmflag import RTMFlag
from aiortnetlink.constants.rtmtype import RTMType
from aiortnetlink.constants.rtntype import RTNType
from aiortnetlink.constants.rtprot import RTProt
from aiortnetlink.constants.rtscope import RTScope
from aiortnetlink.constants.rttable import RTTable
from aiortnetlink.netlink import NetlinkRequest, NLAttr, NLMsg
from aiortnetlink.structs.ifa_cacheinfo import IFACacheInfo

Expand Down Expand Up @@ -75,6 +78,55 @@ def get_route_request(
return NetlinkRequest(RTMType.GETROUTE, flags, data, RTMType.NEWROUTE)


def add_route_request(
destination: IPv4Network | IPv6Network | None = None,
gateway: IPv4Address | IPv6Address | None = None,
oif: int | None = None,
family: int | None = None,
table: int = 254,
) -> NetlinkRequest:
print(f"{destination=} {gateway=} {oif=} {family=}")
flags: int = NLFlag.REQUEST | NLFlag.CREATE | NLFlag.ACK
if destination is not None:
if family is None:
family = socket.AF_INET if destination.version == 4 else socket.AF_INET6
dst_len = destination.max_prefixlen
else:
dst_len = 0

if family is None:
raise ValueError("Route must specify an address family")

if table < 256:
rtm_table = table
else:
rtm_table = RTTable.MAIN

parts = [
RTMsg(
family=family.value,
rtm_type=RTNType.UNICAST,
protocol=RTProt.BOOT,
scope=RTScope.LINK,
dst_len=dst_len,
table=rtm_table,
).encode(),
NLAttr.from_int(RTAType.TABLE, table),
]

if destination is not None:
parts.append(NLAttr.from_ipaddress(RTAType.DST, destination.network_address))

if oif is not None:
parts.append(NLAttr.from_int(RTAType.OIF, oif))

if gateway is not None:
parts.append(NLAttr.from_ipaddress(RTAType.GATEWAY, gateway))

data = b"".join(parts)
return NetlinkRequest(RTMType.NEWROUTE, flags, data, RTMType.NEWROUTE)


@dataclass(slots=True)
class Route:
family: int
Expand Down Expand Up @@ -186,6 +238,20 @@ def rtm_get(
) -> NetlinkRequest:
return get_route_request(address)

@staticmethod
def rtm_add(
destination: IPv4Network | IPv6Network | None = None,
gateway: IPv4Address | IPv6Address | None = None,
oif: int | None = None,
family: int | None = None,
):
return add_route_request(
destination=destination,
gateway=gateway,
oif=oif,
family=family,
)

def friendly_str(
self,
show_table: bool = True,
Expand Down
20 changes: 20 additions & 0 deletions tools/gen_constants.c

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

26 changes: 26 additions & 0 deletions tools/gen_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,29 @@
"RTA_NH_ID",
]

rt_prots = [
"RTPROT_UNSPEC",
"RTPROT_REDIRECT",
"RTPROT_KERNEL",
"RTPROT_BOOT",
"RTPROT_STATIC",
]

rt_scopes = [
"RT_SCOPE_UNIVERSE",
"RT_SCOPE_SITE",
"RT_SCOPE_LINK",
"RT_SCOPE_HOST",
"RT_SCOPE_NOWHERE",
]

rt_tables = [
"RT_TABLE_UNSPEC",
"RT_TABLE_DEFAULT",
"RT_TABLE_MAIN",
"RT_TABLE_LOCAL",
]

arphrd_types = [
"ARPHRD_ETHER",
"ARPHRD_NONE",
Expand Down Expand Up @@ -490,6 +513,9 @@ def __post_init__(self) -> None:
includes=["<linux/in_route.h>"],
),
TypeSpec("RTAType", "RTA_", rta_types, includes=["<linux/rtnetlink.h>"]),
TypeSpec("RTProt", "RTPROT_", rt_prots, includes=["<linux/rtnetlink.h>"]),
TypeSpec("RTScope", "RT_SCOPE_", rt_scopes, includes=["<linux/rtnetlink.h>"]),
TypeSpec("RTTable", "RT_TABLE_", rt_tables, includes=["<linux/rtnetlink.h>"]),
TypeSpec("ARPHRDType", "ARPHRD_", arphrd_types, includes=["<linux/if_arp.h>"]),
TypeSpec("IFLAType", "IFLA_", ifla_types, includes=["<linux/if.h>"]),
TypeSpec("IFFlag", "IFF_", if_flags, flag=True, includes=["<linux/if.h>"]),
Expand Down

0 comments on commit d714726

Please sign in to comment.