Skip to content

Commit

Permalink
Ver1.13.0 (#20)
Browse files Browse the repository at this point in the history
* fix: remove IsEqualsAllTree debug code

* add: ZoneText, ImportRRs

* add NSEC3
  • Loading branch information
mimuret authored Dec 8, 2024
1 parent e418e54 commit ad79d65
Show file tree
Hide file tree
Showing 22 changed files with 662 additions and 240 deletions.
2 changes: 1 addition & 1 deletion ddns/ddns_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func NewTestUpdate() *TestUpdate {

func (u *TestUpdate) GetZone(*dns.Msg) (dnsutils.ZoneInterface, error) {
buf := bytes.NewBuffer(zonefile)
zone, _ := dnsutils.NewZone("example.jp", dns.ClassINET, nil)
zone, _ := dnsutils.NewZone("example.jp.", dns.ClassINET, nil)
if err := zone.Read(buf); err != nil {
return nil, err
}
Expand Down
226 changes: 226 additions & 0 deletions nsec.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
package dnsutils

import (
"errors"
"fmt"
"sort"
"strings"

"github.com/miekg/dns"
)

func CreateDoE(z ZoneInterface, opt SignOption, generator Generator) error {
if generator == nil {
generator = &DefaultGenerator{}
}
switch opt.DoEMethod {
case DenialOfExistenceMethodNSEC, "":
return createNSEC(z, generator)
case DenialOfExistenceMethodNSEC3:
return createNSEC3(z, opt, generator)
}
return fmt.Errorf("not support: %s", opt.DoEMethod)
}

func createNSEC(z ZoneInterface, generator RRSetGenerator) error {
var nodes = map[string]NameNodeInterface{}
var names []string
soa, err := GetSOA(z)
if err != nil {
return ErrBadZone
}

zoneCuts, _, err := GetZoneCuts(z.GetRootNode())
if err != nil {
return ErrBadZone
}

// get next domain names
z.GetRootNode().IterateNameNode(func(nni NameNodeInterface) error {
// Blocks with no types present MUST NOT be included
if nni.RRSetLen() == 0 {
return nil
}
// A zone MUST NOT include an NSEC RR for any domain name that only holds glue records
parent, strict := zoneCuts.GetNameNode(nni.GetName())
if parent.GetName() != z.GetName() {
if !strict && parent.GetRRSet(dns.TypeNS) != nil {
return nil
}
}
nodes[nni.GetName()] = nni
names = append(names, nni.GetName())
return nil
})

SortNames(names)
for i, name := range names {
nsec := &dns.NSEC{
Hdr: dns.RR_Header{
Name: name,
Rrtype: dns.TypeNSEC,
Class: dns.ClassINET,
// The NSEC RR SHOULD have the same TTL value as the SOA minimum TTL field.
// This is in the spirit of negative caching ([RFC2308]).
Ttl: soa.Minttl,
},
TypeBitMap: []uint16{dns.TypeRRSIG, dns.TypeNSEC},
}
if i+1 < len(names) {
nsec.NextDomain = names[i+1]
} else {
nsec.NextDomain = names[0]
}
rresetMap := nodes[name].CopyRRSetMap()
for rtype := range rresetMap {
switch rtype {
case dns.TypeRRSIG:
case dns.TypeNSEC:
default:
nsec.TypeBitMap = append(nsec.TypeBitMap, rtype)
}
}
sort.SliceStable(nsec.TypeBitMap, func(i, j int) bool { return nsec.TypeBitMap[i] < nsec.TypeBitMap[j] })

set, err := generator.NewRRSet(name, soa.Minttl, dns.ClassINET, dns.TypeNSEC)
if err != nil {
return err
}
if err := set.AddRR(nsec); err != nil {
return err
}
if err := nodes[name].SetRRSet(set); err != nil {
return err
}
}
return nil
}

var (
ErrCollision = fmt.Errorf("hash collision detected")
)

func createNSEC3(z ZoneInterface, opt SignOption, generator Generator) error {
var nodes = map[string]NameNodeInterface{}
var hashCheckName = map[string]struct{}{}
var names []string
soa, err := GetSOA(z)
if err != nil {
return ErrBadZone
}

zoneCuts, _, err := GetZoneCuts(z.GetRootNode())
if err != nil {
return ErrBadZone
}
nsec3param := &dns.NSEC3PARAM{
Hdr: dns.RR_Header{
Name: soa.Hdr.Name,
Rrtype: dns.TypeNSEC3PARAM,
Class: dns.ClassINET,
Ttl: 0,
},
Hash: dns.SHA1,
Iterations: opt.GetNSEC3Iterate(),
Salt: opt.GetNSEC3Salt(),
}

nsec3ParamRRRet, err := NewRRSetFromRRWithGenerator(nsec3param, generator)
if err != nil {
return fmt.Errorf("failed to create nsec3param")
}
if err := z.GetRootNode().SetRRSet(nsec3ParamRRRet); err != nil {
return fmt.Errorf("failed to set nsec3param")
}

// get next domain names
err = z.GetRootNode().IterateNameNode(func(nni NameNodeInterface) error {
parent, static := zoneCuts.GetNameNode(nni.GetName())
if parent.GetName() != z.GetName() {
if !static && parent.GetRRSet(dns.TypeNS) != nil {
return nil
}
}
nodes[nni.GetName()] = nni
names = append(names, nni.GetName())

hashCheckName[nni.GetName()] = struct{}{}
labels := dns.SplitDomainName(nni.GetName())
if len(labels) > 0 && labels[0] != "*" {
hashCheckName["*."+nni.GetName()] = struct{}{}
}
return nil
})
if err != nil {
return fmt.Errorf("failed to create name list: %w", err)
}

// collision check and make hash owner name
hashMap := map[string]string{}
hashCheck := map[string]string{}
for name := range hashCheckName {
hashName := dns.HashName(name, dns.SHA1, opt.GetNSEC3Iterate(), opt.GetNSEC3Salt())
hashMap[name] = hashName
if _, exist := hashCheck[hashName]; exist {
return errors.Join(ErrCollision, fmt.Errorf("collision %s %s", hashCheck[hashName], name))
} else {
hashCheck[hashName] = name
}
}
sort.Slice(names, func(i, j int) bool {
cmp, _ := CompareName(hashMap[names[i]], hashMap[names[j]])
return cmp < 0
})
for i, name := range names {
nsec3 := &dns.NSEC3{
Hdr: dns.RR_Header{
Name: dns.CanonicalName(hashMap[name] + "." + z.GetName()),
Rrtype: dns.TypeNSEC3,
Class: dns.ClassINET,
// The NSEC RR SHOULD have the same TTL value as the SOA minimum TTL field.
// This is in the spirit of negative caching ([RFC2308]).
Ttl: soa.Minttl,
},
Hash: dns.SHA1,
Iterations: opt.GetNSEC3Iterate(),
Salt: opt.GetNSEC3Salt(),
SaltLength: uint8(len(opt.GetNSEC3Salt()) / 2),
HashLength: 20, // SHA-1
}
if i+1 < len(names) {
nsec3.NextDomain = strings.ToLower(hashMap[names[i+1]])
} else {
nsec3.NextDomain = strings.ToLower(hashMap[names[0]])
}
rresetMap := nodes[name].CopyRRSetMap()
var (
isZoneCust, haveDS bool
)
for rtype := range rresetMap {
switch rtype {
case dns.TypeRRSIG:
case dns.TypeNSEC:
case dns.TypeDS:
nsec3.TypeBitMap = append(nsec3.TypeBitMap, rtype)
haveDS = true
case dns.TypeNS:
nsec3.TypeBitMap = append(nsec3.TypeBitMap, rtype)
if z.GetName() != name {
isZoneCust = true
}
default:
nsec3.TypeBitMap = append(nsec3.TypeBitMap, rtype)
}
}
if !IsENT(nodes[name]) && (!isZoneCust || isZoneCust && haveDS) {
nsec3.TypeBitMap = append(nsec3.TypeBitMap, dns.TypeRRSIG)
}

sort.SliceStable(nsec3.TypeBitMap, func(i, j int) bool { return nsec3.TypeBitMap[i] < nsec3.TypeBitMap[j] })

if err := CreateOrReplaceRRSetFromRRs(z.GetRootNode(), []dns.RR{nsec3}, generator); err != nil {
return fmt.Errorf("failed to create NSEC3 %s cover %s : %w", nsec3.Header().Name, name, err)
}
}
return nil
}
145 changes: 145 additions & 0 deletions nsec_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
package dnsutils_test

import (
"bytes"
_ "embed"

"github.com/miekg/dns"
"github.com/mimuret/dnsutils"
"github.com/mimuret/dnsutils/testtool"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)

//go:embed testdata/sign/example.jp.nsec3.bind
var testNsec3SignedZone []byte

var _ = Describe("Test nsec.go", func() {
var (
err error
z *dnsutils.Zone
inception = uint32(1704067200)
expiration = uint32(1893456000)
nsecSignOption = dnsutils.SignOption{
DoEMethod: dnsutils.DenialOfExistenceMethodNSEC,
Inception: &inception,
Expiration: &expiration,
}
nsec3SignOption = dnsutils.SignOption{
DoEMethod: dnsutils.DenialOfExistenceMethodNSEC3,
Inception: &inception,
Expiration: &expiration,
}
zsk *dnsutils.DNSKEY
ksk *dnsutils.DNSKEY
dnskeys []*dnsutils.DNSKEY
nsecSignedZone *dnsutils.Zone
nsec3SignedZone *dnsutils.Zone
)
BeforeEach(func() {
ksk, err = dnsutils.ReadDNSKEY(bytes.NewBuffer(testDnskeyED25519KSKPriv), bytes.NewBuffer(testDnskeyED25519KSKPub))
Expect(err).To(Succeed())
zsk, err = dnsutils.ReadDNSKEY(bytes.NewBuffer(testDnskeyED25519ZSKPriv), bytes.NewBuffer(testDnskeyED25519ZSKPub))
Expect(err).To(Succeed())
dnskeys = []*dnsutils.DNSKEY{ksk, zsk}

nsecSignedZone = &dnsutils.Zone{}
err = nsecSignedZone.Read(bytes.NewBuffer(testNsecSignedZone))
Expect(err).To(Succeed())

nsec3SignedZone = &dnsutils.Zone{}
err = nsec3SignedZone.Read(bytes.NewBuffer(testNsec3SignedZone))
Expect(err).To(Succeed())
})
It("can read ED25519 zsk/ksk", func() {
Expect(ksk.GetRR().KeyTag()).To(Equal(uint16(2290)))
Expect(ksk.GetSigner().Public())
Expect(zsk.GetRR().KeyTag()).To(Equal(uint16(30075)))
})
Context("CreateDoE", func() {
When("NSEC", func() {
BeforeEach(func() {
testZoneNormalBuf := bytes.NewBuffer(testSignZone)
z = &dnsutils.Zone{}
err = z.Read(testZoneNormalBuf)
Expect(err).To(Succeed())
err = dnsutils.CreateDoE(z, nsecSignOption, nil)
})
It("return success", func() {
Expect(err).To(Succeed())
var nsecRRs []dns.RR
z.GetRootNode().IterateNameNode(func(nni dnsutils.NameNodeInterface) error {
if nsecRRSet := nni.GetRRSet(dns.TypeNSEC); nsecRRSet != nil {
nsecRRs = append(nsecRRs, nsecRRSet.GetRRs()...)
}
return nil
})
Expect(nsecRRs[0]).To(Equal(testtool.MustNewRR("example.jp. 300 IN NSEC \\000.example.jp. NS SOA RRSIG NSEC")))
Expect(nsecRRs[1]).To(Equal(testtool.MustNewRR("\\000.example.jp. 300 IN NSEC *.example.jp. TXT RRSIG NSEC")))
Expect(nsecRRs[2]).To(Equal(testtool.MustNewRR("*.example.jp. 300 IN NSEC test.hoge.example.jp. A RRSIG NSEC")))
Expect(nsecRRs[3]).To(Equal(testtool.MustNewRR("test.hoge.example.jp. 300 IN NSEC www.hoge.example.jp. A RRSIG NSEC")))
Expect(nsecRRs[4]).To(Equal(testtool.MustNewRR("www.hoge.example.jp. 300 IN NSEC sub1.example.jp. CNAME RRSIG NSEC")))
Expect(nsecRRs[5]).To(Equal(testtool.MustNewRR("sub1.example.jp. 300 IN NSEC sub2.example.jp. NS DS RRSIG NSEC")))
Expect(nsecRRs[6]).To(Equal(testtool.MustNewRR("sub2.example.jp. 300 IN NSEC example.jp. NS RRSIG NSEC")))
})
Context("Test for Sign with NSEC", func() {
BeforeEach(func() {
z = &dnsutils.Zone{}
err = z.Read(bytes.NewBuffer(testSignZone))
Expect(err).To(Succeed())
err = dnsutils.AddDNSKEY(z, dnskeys, uint32(0), nil)
Expect(err).To(Succeed())
err = dnsutils.CreateDoE(z, nsecSignOption, nil)
Expect(err).To(Succeed())
err = dnsutils.SignZone(z, nsecSignOption, dnskeys, nil)
})
It("return success", func() {
Expect(err).To(Succeed())
Expect(dnsutils.IsEqualsAllTree(z.GetRootNode(), nsecSignedZone.GetRootNode(), false)).To(BeTrue())
})
})
})
When("NSEC3", func() {
BeforeEach(func() {
testZoneNormalBuf := bytes.NewBuffer(testSignZone)
z = &dnsutils.Zone{}
err = z.Read(testZoneNormalBuf)
Expect(err).To(Succeed())
err = dnsutils.CreateDoE(z, nsec3SignOption, nil)
})
It("return success", func() {
Expect(err).To(Succeed())
var nsec3RRs []dns.RR
var nsec3params []dns.RR
z.GetRootNode().IterateNameNode(func(nni dnsutils.NameNodeInterface) error {
if nsec3RRSet := nni.GetRRSet(dns.TypeNSEC3); nsec3RRSet != nil {
nsec3RRs = append(nsec3RRs, nsec3RRSet.GetRRs()...)
}
return nil
})
if nsec3paramRRSet := z.GetRootNode().GetRRSet(dns.TypeNSEC3PARAM); nsec3paramRRSet != nil {
nsec3params = nsec3paramRRSet.GetRRs()
}
Expect(nsec3params).To(HaveLen(1))
Expect(nsec3params[0]).To(Equal(testtool.MustNewRR("example.jp. 0 IN NSEC3PARAM 1 0 0 -")))
Expect(nsec3RRs).To(HaveLen(8))
})
Context("Test for Sign with NSEC3", func() {
BeforeEach(func() {
z = &dnsutils.Zone{}
err = z.Read(bytes.NewBuffer(testSignZone))
Expect(err).To(Succeed())
err = dnsutils.AddDNSKEY(z, dnskeys, uint32(0), nil)
Expect(err).To(Succeed())
err = dnsutils.CreateDoE(z, nsec3SignOption, nil)
Expect(err).To(Succeed())
err = dnsutils.SignZone(z, nsec3SignOption, dnskeys, nil)
})
It("return success", func() {
Expect(err).To(Succeed())
Expect(dnsutils.IsEqualsAllTree(z.GetRootNode(), nsec3SignedZone.GetRootNode(), false)).To(BeTrue())
})
})
})
})
})
Loading

0 comments on commit ad79d65

Please sign in to comment.