Skip to content

Commit e19f48c

Browse files
committed
fix getting subnet cidr by protocol (#4844)
Signed-off-by: zhangzujian <zhangzujian.7@gmail.com>
1 parent 5e8288b commit e19f48c

File tree

3 files changed

+77
-16
lines changed

3 files changed

+77
-16
lines changed

pkg/daemon/gateway.go

+8-13
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ func (c *Controller) getSubnetsNeedNAT(protocol string) ([]string, error) {
106106
for _, subnet := range subnets {
107107
if c.isSubnetNeedNat(subnet, protocol) {
108108
cidrBlock, err := getCidrByProtocol(subnet.Spec.CIDRBlock, protocol)
109-
if err == nil {
109+
if err == nil && cidrBlock != "" {
110110
subnetsNeedNat = append(subnetsNeedNat, cidrBlock)
111111
}
112112
}
@@ -146,7 +146,7 @@ func (c *Controller) getSubnetsDistributedGateway(protocol string) ([]string, er
146146
subnet.Spec.GatewayType == kubeovnv1.GWDistributedType &&
147147
(subnet.Spec.Protocol == kubeovnv1.ProtocolDual || subnet.Spec.Protocol == protocol) {
148148
cidrBlock, err := getCidrByProtocol(subnet.Spec.CIDRBlock, protocol)
149-
if err == nil {
149+
if err == nil && cidrBlock != "" {
150150
result = append(result, cidrBlock)
151151
}
152152
}
@@ -177,7 +177,7 @@ func (c *Controller) getDefaultVpcSubnetsCIDR(protocol string) ([]string, map[st
177177
for _, subnet := range subnets {
178178
if subnet.Spec.Vpc == c.config.ClusterRouter && (subnet.Spec.Vlan == "" || subnet.Spec.LogicalGateway) && subnet.Spec.CIDRBlock != "" {
179179
cidrBlock, err := getCidrByProtocol(subnet.Spec.CIDRBlock, protocol)
180-
if err == nil {
180+
if err == nil && cidrBlock != "" {
181181
ret = append(ret, cidrBlock)
182182
subnetMap[subnet.Name] = cidrBlock
183183
}
@@ -209,22 +209,17 @@ func (c *Controller) getOtherNodes(protocol string) ([]string, error) {
209209
}
210210

211211
func getCidrByProtocol(cidr, protocol string) (string, error) {
212-
var cidrStr string
213212
if err := util.CheckCidrs(cidr); err != nil {
214213
return "", err
215214
}
216215

217-
if util.CheckProtocol(cidr) == kubeovnv1.ProtocolDual {
218-
cidrBlocks := strings.Split(cidr, ",")
219-
if protocol == kubeovnv1.ProtocolIPv4 {
220-
cidrStr = cidrBlocks[0]
221-
} else if protocol == kubeovnv1.ProtocolIPv6 {
222-
cidrStr = cidrBlocks[1]
216+
for _, cidr := range strings.Split(cidr, ",") {
217+
if util.CheckProtocol(cidr) == protocol {
218+
return cidr, nil
223219
}
224-
} else {
225-
cidrStr = cidr
226220
}
227-
return cidrStr, nil
221+
222+
return "", nil
228223
}
229224

230225
func (c *Controller) getEgressNatIPByNode(nodeName string) (map[string]string, error) {

pkg/daemon/gateway_linux.go

+8-3
Original file line numberDiff line numberDiff line change
@@ -204,15 +204,17 @@ func (c *Controller) reconcileNatOutGoingPolicyIPset(protocol string) {
204204
return
205205
}
206206

207-
subnetCidrs := make([]string, 0)
207+
subnetCidrs := make([]string, 0, len(subnets))
208208
natPolicyRuleIDs := strset.New()
209209
for _, subnet := range subnets {
210210
cidrBlock, err := getCidrByProtocol(subnet.Spec.CIDRBlock, protocol)
211211
if err != nil {
212212
klog.Errorf("failed to get subnet %s CIDR block by protocol: %v", subnet.Name, err)
213213
continue
214214
}
215-
subnetCidrs = append(subnetCidrs, cidrBlock)
215+
if cidrBlock != "" {
216+
subnetCidrs = append(subnetCidrs, cidrBlock)
217+
}
216218
for _, rule := range subnet.Status.NatOutgoingPolicyRules {
217219
if rule.RuleID == "" {
218220
klog.Errorf("unexpected empty ID for NAT outgoing rule %q of subnet %s", rule.NatOutgoingPolicyRule, subnet.Name)
@@ -965,6 +967,9 @@ func (c *Controller) generateNatOutgoingPolicyChainRules(protocol string) ([]uti
965967
klog.Errorf("failed to get subnet %s cidr block with protocol: %v", subnet.Name, err)
966968
continue
967969
}
970+
if cidrBlock == "" {
971+
continue
972+
}
968973

969974
ovnNatPolicySubnetChainName := OvnNatOutGoingPolicySubnet + util.GetTruncatedUID(string(subnet.GetUID()))
970975
natPolicySubnetIptables = append(natPolicySubnetIptables, util.IPTableRule{Table: NAT, Chain: OvnNatOutGoingPolicy, Rule: strings.Fields(fmt.Sprintf(`-s %s -m comment --comment natPolicySubnet-%s -j %s`, cidrBlock, subnet.Name, ovnNatPolicySubnetChainName))})
@@ -1553,7 +1558,7 @@ func (c *Controller) getSubnetsNeedPR(protocol string) (map[policyRouteMeta]stri
15531558
}
15541559
if meta.gateway != "" {
15551560
cidrBlock, err := getCidrByProtocol(subnet.Spec.CIDRBlock, protocol)
1556-
if err == nil {
1561+
if err == nil && cidrBlock != "" {
15571562
subnetsNeedPR[meta] = cidrBlock
15581563
}
15591564
}

pkg/daemon/gateway_test.go

+61
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
package daemon
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/require"
7+
8+
kubeovnv1 "github.com/kubeovn/kube-ovn/pkg/apis/kubeovn/v1"
9+
)
10+
11+
func TestGetCidrByProtocol(t *testing.T) {
12+
cases := []struct {
13+
name string
14+
cidr string
15+
protocol string
16+
wantErr bool
17+
expetced string
18+
}{{
19+
name: "ipv4 only",
20+
cidr: "1.1.1.0/24",
21+
protocol: kubeovnv1.ProtocolIPv4,
22+
expetced: "1.1.1.0/24",
23+
}, {
24+
name: "ipv6 only",
25+
cidr: "2001:db8::/120",
26+
protocol: kubeovnv1.ProtocolIPv6,
27+
expetced: "2001:db8::/120",
28+
}, {
29+
name: "get ipv4 from ipv6",
30+
cidr: "2001:db8::/120",
31+
protocol: kubeovnv1.ProtocolIPv4,
32+
}, {
33+
name: "get ipv4 from dual stack",
34+
cidr: "1.1.1.0/24,2001:db8::/120",
35+
protocol: kubeovnv1.ProtocolIPv4,
36+
expetced: "1.1.1.0/24",
37+
}, {
38+
name: "get ipv6 from ipv4",
39+
cidr: "1.1.1.0/24",
40+
protocol: kubeovnv1.ProtocolIPv6,
41+
}, {
42+
name: "get ipv6 from dual stack",
43+
cidr: "1.1.1.0/24,2001:db8::/120",
44+
protocol: kubeovnv1.ProtocolIPv6,
45+
expetced: "2001:db8::/120",
46+
}, {
47+
name: "invalid cidr",
48+
cidr: "foo bar",
49+
protocol: kubeovnv1.ProtocolIPv4,
50+
wantErr: true,
51+
}}
52+
for _, c := range cases {
53+
t.Run(c.name, func(t *testing.T) {
54+
got, err := getCidrByProtocol(c.cidr, c.protocol)
55+
if (err != nil) != c.wantErr {
56+
t.Errorf("getCidrByProtocol(%q, %q) error = %v, wantErr = %v", c.cidr, c.protocol, err, c.wantErr)
57+
}
58+
require.Equal(t, c.expetced, got)
59+
})
60+
}
61+
}

0 commit comments

Comments
 (0)