commit 49a0f1d403aee48c1bc1c26e728e448cf975d363 Author: restitux Date: Thu Jul 10 00:44:54 2025 -0600 initial commit diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..1249731 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). 
+ + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. 
Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative 
Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2016-2020 The CoreDNS authors and contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md new file mode 100644 index 0000000..ad21e77 --- /dev/null +++ b/README.md @@ -0,0 +1,3 @@ +# host_specific_cache + +This is a fork of the built-in CoreDNS cache plugin that is host aware. This means that the cache is not shared between different requesting IPs. This is to enable compatibility with the DNS view feature in CoreDNS. diff --git a/cache.go b/cache.go new file mode 100644 index 0000000..621518d --- /dev/null +++ b/cache.go @@ -0,0 +1,325 @@ +// Package cache implements a cache. 
+package cache + +import ( + "hash/fnv" + "net" + "time" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/pkg/cache" + "github.com/coredns/coredns/plugin/pkg/dnsutil" + "github.com/coredns/coredns/plugin/pkg/response" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +// Cache is a plugin that looks up responses in a cache and caches replies. +// It has a success and a denial of existence cache. +type Cache struct { + Next plugin.Handler + Zones []string + + zonesMetricLabel string + viewMetricLabel string + + ncache *cache.Cache + ncap int + nttl time.Duration + minnttl time.Duration + + pcache *cache.Cache + pcap int + pttl time.Duration + minpttl time.Duration + failttl time.Duration // TTL for caching SERVFAIL responses + + // Prefetch. + prefetch int + duration time.Duration + percentage int + + // Stale serve + staleUpTo time.Duration + verifyStale bool + + // Positive/negative zone exceptions + pexcept []string + nexcept []string + + // Keep ttl option + keepttl bool + + // Testing. + now func() time.Time +} + +// New returns an initialized Cache with default settings. It's up to the +// caller to set the Next handler. +func New() *Cache { + return &Cache{ + Zones: []string{"."}, + pcap: defaultCap, + pcache: cache.New(defaultCap), + pttl: maxTTL, + minpttl: minTTL, + ncap: defaultCap, + ncache: cache.New(defaultCap), + nttl: maxNTTL, + minnttl: minNTTL, + failttl: minNTTL, + prefetch: 0, + duration: 1 * time.Minute, + percentage: 10, + now: time.Now, + } +} + +// key returns key under which we store the item, -1 will be returned if we don't store the message. +// Currently we do not cache Truncated, errors zone transfers or dynamic update messages. +// qname holds the already lowercased qname. +func key(qname string, remoteIP string, m *dns.Msg, t response.Type, do, cd bool) (bool, uint64) { + // We don't store truncated responses. + if m.Truncated { + return false, 0 + } + // Nor errors or Meta or Update. 
+ if t == response.OtherError || t == response.Meta || t == response.Update { + return false, 0 + } + + return true, hash(qname, remoteIP, m.Question[0].Qtype, do, cd) +} + +var ( + one = []byte("1") + zero = []byte("0") +) + +func hash(qname string, remoteIP string, qtype uint16, do, cd bool) uint64 { + h := fnv.New64() + + if do { + h.Write(one) + } else { + h.Write(zero) + } + + if cd { + h.Write(one) + } else { + h.Write(zero) + } + + h.Write([]byte{byte(qtype >> 8)}) + h.Write([]byte{byte(qtype)}) + h.Write([]byte(remoteIP)) + h.Write([]byte(qname)) + return h.Sum64() +} + +func computeTTL(msgTTL, minTTL, maxTTL time.Duration) time.Duration { + ttl := msgTTL + if ttl < minTTL { + ttl = minTTL + } + if ttl > maxTTL { + ttl = maxTTL + } + return ttl +} + +// ResponseWriter is a response writer that caches the reply message. +type ResponseWriter struct { + dns.ResponseWriter + *Cache + state request.Request + server string // Server handling the request. + + do bool // When true the original request had the DO bit set. + cd bool // When true the original request had the CD bit set. + ad bool // When true the original request had the AD bit set. + prefetch bool // When true write nothing back to the client. + remoteAddr net.Addr + + wildcardFunc func() string // function to retrieve wildcard name that synthesized the result. + + pexcept []string // positive zone exceptions + nexcept []string // negative zone exceptions +} + +// newPrefetchResponseWriter returns a Cache ResponseWriter to be used in +// prefetch requests. It ensures RemoteAddr() can be called even after the +// original connection has already been closed. +func newPrefetchResponseWriter(server string, state request.Request, c *Cache) *ResponseWriter { + // Resolve the address now, the connection might be already closed when the + // actual prefetch request is made. + addr := state.W.RemoteAddr() + // The protocol of the client triggering a cache prefetch doesn't matter. 
+ // The address type is used by request.Proto to determine the response size, + // and using TCP ensures the message isn't unnecessarily truncated. + if u, ok := addr.(*net.UDPAddr); ok { + addr = &net.TCPAddr{IP: u.IP, Port: u.Port, Zone: u.Zone} + } + + return &ResponseWriter{ + ResponseWriter: state.W, + Cache: c, + state: state, + server: server, + do: state.Do(), + cd: state.Req.CheckingDisabled, + prefetch: true, + remoteAddr: addr, + } +} + +// RemoteAddr implements the dns.ResponseWriter interface. +func (w *ResponseWriter) RemoteAddr() net.Addr { + if w.remoteAddr != nil { + return w.remoteAddr + } + return w.ResponseWriter.RemoteAddr() +} + +// WriteMsg implements the dns.ResponseWriter interface. +func (w *ResponseWriter) WriteMsg(res *dns.Msg) error { + res = res.Copy() + mt, _ := response.Typify(res, w.now().UTC()) + + // key returns empty string for anything we don't want to cache. + hasKey, key := key(w.state.Name(), w.state.IP(), res, mt, w.do, w.cd) + + msgTTL := dnsutil.MinimalTTL(res, mt) + var duration time.Duration + switch mt { + case response.NameError, response.NoData: + duration = computeTTL(msgTTL, w.minnttl, w.nttl) + case response.ServerError: + duration = w.failttl + default: + duration = computeTTL(msgTTL, w.minpttl, w.pttl) + } + + // Apply capped TTL to this reply to avoid jarring TTL experience 1799 -> 8 (e.g.) 
+ ttl := uint32(duration.Seconds()) + res.Answer = filterRRSlice(res.Answer, ttl, false) + res.Ns = filterRRSlice(res.Ns, ttl, false) + res.Extra = filterRRSlice(res.Extra, ttl, false) + + if !w.do && !w.ad { + // unset AD bit if requester is not OK with DNSSEC + // But retain AD bit if requester set the AD bit in the request, per RFC6840 5.7-5.8 + res.AuthenticatedData = false + } + + if hasKey && duration > 0 { + if w.state.Match(res) { + w.set(res, key, mt, duration) + cacheSize.WithLabelValues(w.server, Success, w.zonesMetricLabel, w.viewMetricLabel).Set(float64(w.pcache.Len())) + cacheSize.WithLabelValues(w.server, Denial, w.zonesMetricLabel, w.viewMetricLabel).Set(float64(w.ncache.Len())) + } else { + // Don't log it, but increment counter + cacheDrops.WithLabelValues(w.server, w.zonesMetricLabel, w.viewMetricLabel).Inc() + } + } + + if w.prefetch { + return nil + } + + return w.ResponseWriter.WriteMsg(res) +} + +func (w *ResponseWriter) set(m *dns.Msg, key uint64, mt response.Type, duration time.Duration) { + // duration is expected > 0 + // and key is valid + switch mt { + case response.NoError, response.Delegation: + if plugin.Zones(w.pexcept).Matches(m.Question[0].Name) != "" { + // zone is in exception list, do not cache + return + } + i := newItem(m, w.now(), duration) + if w.wildcardFunc != nil { + i.wildcard = w.wildcardFunc() + } + if w.pcache.Add(key, i) { + evictions.WithLabelValues(w.server, Success, w.zonesMetricLabel, w.viewMetricLabel).Inc() + } + // when pre-fetching, remove the negative cache entry if it exists + if w.prefetch { + w.ncache.Remove(key) + } + + case response.NameError, response.NoData, response.ServerError: + if plugin.Zones(w.nexcept).Matches(m.Question[0].Name) != "" { + // zone is in exception list, do not cache + return + } + i := newItem(m, w.now(), duration) + if w.wildcardFunc != nil { + i.wildcard = w.wildcardFunc() + } + if w.ncache.Add(key, i) { + evictions.WithLabelValues(w.server, Denial, w.zonesMetricLabel, 
w.viewMetricLabel).Inc() + } + + case response.OtherError: + // don't cache these + default: + log.Warningf("Caching called with unknown classification: %d", mt) + } +} + +// Write implements the dns.ResponseWriter interface. +func (w *ResponseWriter) Write(buf []byte) (int, error) { + log.Warning("Caching called with Write: not caching reply") + if w.prefetch { + return 0, nil + } + n, err := w.ResponseWriter.Write(buf) + return n, err +} + +// verifyStaleResponseWriter is a response writer that only writes messages if they should replace a +// stale cache entry, and otherwise discards them. +type verifyStaleResponseWriter struct { + *ResponseWriter + refreshed bool // set to true if the last WriteMsg wrote to ResponseWriter, false otherwise. +} + +// newVerifyStaleResponseWriter returns a ResponseWriter to be used when verifying stale cache +// entries. It only forward writes if an entry was successfully refreshed according to RFC8767, +// section 4 (response is NoError or NXDomain), and ignores any other response. +func newVerifyStaleResponseWriter(w *ResponseWriter) *verifyStaleResponseWriter { + return &verifyStaleResponseWriter{ + w, + false, + } +} + +// WriteMsg implements the dns.ResponseWriter interface. +func (w *verifyStaleResponseWriter) WriteMsg(res *dns.Msg) error { + w.refreshed = false + if res.Rcode == dns.RcodeSuccess || res.Rcode == dns.RcodeNameError { + w.refreshed = true + return w.ResponseWriter.WriteMsg(res) // stores to the cache and send to client + } + return nil // else discard +} + +const ( + maxTTL = dnsutil.MaximumDefaulTTL + minTTL = dnsutil.MinimalDefaultTTL + maxNTTL = dnsutil.MaximumDefaulTTL / 2 + minNTTL = dnsutil.MinimalDefaultTTL + + defaultCap = 10000 // default capacity of the cache. + + // Success is the class for caching positive caching. + Success = "success" + // Denial is the class defined for negative caching. 
+ Denial = "denial" +) diff --git a/cache_test.go b/cache_test.go new file mode 100644 index 0000000..fd3dc1a --- /dev/null +++ b/cache_test.go @@ -0,0 +1,903 @@ +package cache + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/metadata" + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/pkg/response" + "github.com/coredns/coredns/plugin/test" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +func cacheMsg(m *dns.Msg, tc test.Case) *dns.Msg { + m.RecursionAvailable = tc.RecursionAvailable + m.AuthenticatedData = tc.AuthenticatedData + m.CheckingDisabled = tc.CheckingDisabled + m.Authoritative = tc.Authoritative + m.Rcode = tc.Rcode + m.Truncated = tc.Truncated + m.Answer = tc.Answer + m.Ns = tc.Ns + // m.Extra = tc.in.Extra don't copy Extra, because we don't care and fake EDNS0 DO with tc.Do. + return m +} + +func newTestCache(ttl time.Duration) (*Cache, *ResponseWriter) { + c := New() + c.pttl = ttl + c.nttl = ttl + + crr := &ResponseWriter{ResponseWriter: nil, Cache: c} + crr.nexcept = []string{"neg-disabled.example.org."} + crr.pexcept = []string{"pos-disabled.example.org."} + + return c, crr +} + +// TestCacheInsertion verifies the insertion of items to the cache. +func TestCacheInsertion(t *testing.T) { + cacheTestCases := []struct { + name string + out test.Case // the expected message coming "out" of cache + in test.Case // the test message going "in" to cache + shouldCache bool + }{ + { + name: "test ad bit cache", + out: test.Case{ + Qname: "miek.nl.", Qtype: dns.TypeMX, + Answer: []dns.RR{ + test.MX("miek.nl. 3600 IN MX 1 aspmx.l.google.com."), + test.MX("miek.nl. 3600 IN MX 10 aspmx2.googlemail.com."), + }, + RecursionAvailable: true, + AuthenticatedData: true, + }, + in: test.Case{ + Qname: "miek.nl.", Qtype: dns.TypeMX, + Answer: []dns.RR{ + test.MX("miek.nl. 3601 IN MX 1 aspmx.l.google.com."), + test.MX("miek.nl. 
3601 IN MX 10 aspmx2.googlemail.com."), + }, + RecursionAvailable: true, + AuthenticatedData: true, + }, + shouldCache: true, + }, + { + name: "test case sensitivity cache", + out: test.Case{ + Qname: "miek.nl.", Qtype: dns.TypeMX, + Answer: []dns.RR{ + test.MX("miek.nl. 3600 IN MX 1 aspmx.l.google.com."), + test.MX("miek.nl. 3600 IN MX 10 aspmx2.googlemail.com."), + }, + RecursionAvailable: true, + AuthenticatedData: true, + }, + in: test.Case{ + Qname: "mIEK.nL.", Qtype: dns.TypeMX, + Answer: []dns.RR{ + test.MX("miek.nl. 3601 IN MX 1 aspmx.l.google.com."), + test.MX("miek.nl. 3601 IN MX 10 aspmx2.googlemail.com."), + }, + RecursionAvailable: true, + AuthenticatedData: true, + }, + shouldCache: true, + }, + { + name: "test truncated responses shouldn't cache", + in: test.Case{ + Qname: "miek.nl.", Qtype: dns.TypeMX, + Answer: []dns.RR{test.MX("miek.nl. 1800 IN MX 1 aspmx.l.google.com.")}, + Truncated: true, + }, + shouldCache: false, + }, + { + name: "test dns.RcodeNameError cache", + out: test.Case{ + Rcode: dns.RcodeNameError, + Qname: "example.org.", Qtype: dns.TypeA, + Ns: []dns.RR{ + test.SOA("example.org. 3600 IN SOA sns.dns.icann.org. noc.dns.icann.org. 2016082540 7200 3600 1209600 3600"), + }, + RecursionAvailable: true, + }, + in: test.Case{ + Rcode: dns.RcodeNameError, + Qname: "example.org.", Qtype: dns.TypeA, + Ns: []dns.RR{ + test.SOA("example.org. 3600 IN SOA sns.dns.icann.org. noc.dns.icann.org. 
2016082540 7200 3600 1209600 3600"), + }, + RecursionAvailable: true, + }, + shouldCache: true, + }, + { + name: "test dns.RcodeServerFailure cache", + out: test.Case{ + Rcode: dns.RcodeServerFailure, + Qname: "example.org.", Qtype: dns.TypeA, + Ns: []dns.RR{}, + RecursionAvailable: true, + }, + in: test.Case{ + Rcode: dns.RcodeServerFailure, + Qname: "example.org.", Qtype: dns.TypeA, + Ns: []dns.RR{}, + RecursionAvailable: true, + }, + shouldCache: true, + }, + { + name: "test dns.RcodeNotImplemented cache", + out: test.Case{ + Rcode: dns.RcodeNotImplemented, + Qname: "example.org.", Qtype: dns.TypeA, + Ns: []dns.RR{}, + RecursionAvailable: true, + }, + in: test.Case{ + Rcode: dns.RcodeNotImplemented, + Qname: "example.org.", Qtype: dns.TypeA, + Ns: []dns.RR{}, + RecursionAvailable: true, + }, + shouldCache: true, + }, + { + name: "test expired RRSIG doesn't cache", + in: test.Case{ + Qname: "miek.nl.", Qtype: dns.TypeMX, + Do: true, + Answer: []dns.RR{ + test.MX("miek.nl. 3600 IN MX 1 aspmx.l.google.com."), + test.MX("miek.nl. 3600 IN MX 10 aspmx2.googlemail.com."), + test.RRSIG("miek.nl. 1800 IN RRSIG MX 8 2 1800 20160521031301 20160421031301 12051 miek.nl. lAaEzB5teQLLKyDenatmyhca7blLRg9DoGNrhe3NReBZN5C5/pMQk8Jc u25hv2fW23/SLm5IC2zaDpp2Fzgm6Jf7e90/yLcwQPuE7JjS55WMF+HE LEh7Z6AEb+Iq4BWmNhUz6gPxD4d9eRMs7EAzk13o1NYi5/JhfL6IlaYy qkc="), + }, + RecursionAvailable: true, + }, + shouldCache: false, + }, + { + name: "test DO bit with RRSIG not expired cache", + out: test.Case{ + Qname: "example.org.", Qtype: dns.TypeMX, + Do: true, + Answer: []dns.RR{ + test.MX("example.org. 3600 IN MX 1 aspmx.l.google.com."), + test.MX("example.org. 3600 IN MX 10 aspmx2.googlemail.com."), + test.RRSIG("example.org. 3600 IN RRSIG MX 8 2 1800 20170521031301 20170421031301 12051 miek.nl. 
lAaEzB5teQLLKyDenatmyhca7blLRg9DoGNrhe3NReBZN5C5/pMQk8Jc u25hv2fW23/SLm5IC2zaDpp2Fzgm6Jf7e90/yLcwQPuE7JjS55WMF+HE LEh7Z6AEb+Iq4BWmNhUz6gPxD4d9eRMs7EAzk13o1NYi5/JhfL6IlaYy qkc="), + }, + RecursionAvailable: true, + }, + in: test.Case{ + Qname: "example.org.", Qtype: dns.TypeMX, + Do: true, + Answer: []dns.RR{ + test.MX("example.org. 3600 IN MX 1 aspmx.l.google.com."), + test.MX("example.org. 3600 IN MX 10 aspmx2.googlemail.com."), + test.RRSIG("example.org. 1800 IN RRSIG MX 8 2 1800 20170521031301 20170421031301 12051 miek.nl. lAaEzB5teQLLKyDenatmyhca7blLRg9DoGNrhe3NReBZN5C5/pMQk8Jc u25hv2fW23/SLm5IC2zaDpp2Fzgm6Jf7e90/yLcwQPuE7JjS55WMF+HE LEh7Z6AEb+Iq4BWmNhUz6gPxD4d9eRMs7EAzk13o1NYi5/JhfL6IlaYy qkc="), + }, + RecursionAvailable: true, + }, + shouldCache: true, + }, + { + name: "test CD bit cache", + out: test.Case{ + Rcode: dns.RcodeSuccess, + Qname: "dnssec-failed.org.", + Qtype: dns.TypeA, + Answer: []dns.RR{ + test.A("dnssec-failed.org. 3600 IN A 127.0.0.1"), + }, + CheckingDisabled: true, + }, + in: test.Case{ + Rcode: dns.RcodeSuccess, + Qname: "dnssec-failed.org.", + Answer: []dns.RR{ + test.A("dnssec-failed.org. 3600 IN A 127.0.0.1"), + }, + Qtype: dns.TypeA, + CheckingDisabled: true, + }, + shouldCache: true, + }, + { + name: "test negative zone exception shouldn't cache", + in: test.Case{ + Rcode: dns.RcodeNameError, + Qname: "neg-disabled.example.org.", Qtype: dns.TypeA, + Ns: []dns.RR{ + test.SOA("example.org. 3600 IN SOA sns.dns.icann.org. noc.dns.icann.org. 2016082540 7200 3600 1209600 3600"), + }, + }, + shouldCache: false, + }, + { + name: "test positive zone exception shouldn't cache", + in: test.Case{ + Rcode: dns.RcodeSuccess, + Qname: "pos-disabled.example.org.", Qtype: dns.TypeA, + Answer: []dns.RR{ + test.A("pos-disabled.example.org. 
3600 IN A 127.0.0.1"), + }, + }, + shouldCache: false, + }, + { + name: "test positive zone exception with negative answer cache", + in: test.Case{ + Rcode: dns.RcodeNameError, + Qname: "pos-disabled.example.org.", Qtype: dns.TypeA, + Ns: []dns.RR{ + test.SOA("example.org. 3600 IN SOA sns.dns.icann.org. noc.dns.icann.org. 2016082540 7200 3600 1209600 3600"), + }, + }, + out: test.Case{ + Rcode: dns.RcodeNameError, + Qname: "pos-disabled.example.org.", Qtype: dns.TypeA, + Ns: []dns.RR{ + test.SOA("example.org. 3600 IN SOA sns.dns.icann.org. noc.dns.icann.org. 2016082540 7200 3600 1209600 3600"), + }, + }, + shouldCache: true, + }, + { + name: "test negative zone exception with positive answer cache", + in: test.Case{ + Rcode: dns.RcodeSuccess, + Qname: "neg-disabled.example.org.", Qtype: dns.TypeA, + Answer: []dns.RR{ + test.A("neg-disabled.example.org. 3600 IN A 127.0.0.1"), + }, + }, + out: test.Case{ + Rcode: dns.RcodeSuccess, + Qname: "neg-disabled.example.org.", Qtype: dns.TypeA, + Answer: []dns.RR{ + test.A("neg-disabled.example.org. 3600 IN A 127.0.0.1"), + }, + }, + shouldCache: true, + }, + } + now, _ := time.Parse(time.UnixDate, "Fri Apr 21 10:51:21 BST 2017") + utc := now.UTC() + + for _, tc := range cacheTestCases { + t.Run(tc.name, func(t *testing.T) { + // Create a new cache every time to prevent accidental comparison with a previous item. 
+ c, crr := newTestCache(maxTTL) + + m := tc.in.Msg() + m = cacheMsg(m, tc.in) + + state := request.Request{W: &test.ResponseWriter{}, Req: m} + + mt, _ := response.Typify(m, utc) + valid, k := key(state.Name(), m, mt, state.Do(), state.Req.CheckingDisabled) + + if valid { + // Insert cache entry + crr.set(m, k, mt, c.pttl) + } + + // Attempt to retrieve cache entry + i := c.getIgnoreTTL(time.Now().UTC(), state, "dns://:53") + found := i != nil + + if !tc.shouldCache && found { + t.Fatalf("Cached message that should not have been cached: %s", state.Name()) + } + if tc.shouldCache && !found { + t.Fatalf("Did not cache message that should have been cached: %s", state.Name()) + } + + if found { + resp := i.toMsg(m, time.Now().UTC(), state.Do(), m.AuthenticatedData) + + // TODO: If we incorporate these individual checks into the + // test.Header function, we can eliminate them from here. + // Cache entries are always Authoritative. + if resp.Authoritative != true { + t.Error("Expected Authoritative Answer bit to be true, but was false") + } + if resp.AuthenticatedData != tc.out.AuthenticatedData { + t.Errorf("Expected Authenticated Data bit to be %t, but got %t", tc.out.AuthenticatedData, resp.AuthenticatedData) + } + if resp.RecursionAvailable != tc.out.RecursionAvailable { + t.Errorf("Expected Recursion Available bit to be %t, but got %t", tc.out.RecursionAvailable, resp.RecursionAvailable) + } + if resp.CheckingDisabled != tc.out.CheckingDisabled { + t.Errorf("Expected Checking Disabled bit to be %t, but got %t", tc.out.CheckingDisabled, resp.CheckingDisabled) + } + + if err := test.Header(tc.out, resp); err != nil { + t.Logf("Cache %v", resp) + t.Error(err) + } + if err := test.Section(tc.out, test.Answer, resp.Answer); err != nil { + t.Logf("Cache %v -- %v", test.Answer, resp.Answer) + t.Error(err) + } + if err := test.Section(tc.out, test.Ns, resp.Ns); err != nil { + t.Error(err) + } + if err := test.Section(tc.out, test.Extra, resp.Extra); err != nil { + 
t.Error(err) + } + } + }) + } +} + +func TestCacheZeroTTL(t *testing.T) { + c := New() + c.minpttl = 0 + c.minnttl = 0 + c.Next = ttlBackend(0) + + req := new(dns.Msg) + req.SetQuestion("example.org.", dns.TypeA) + ctx := context.TODO() + + c.ServeDNS(ctx, &test.ResponseWriter{}, req) + if c.pcache.Len() != 0 { + t.Errorf("Msg with 0 TTL should not have been cached") + } + if c.ncache.Len() != 0 { + t.Errorf("Msg with 0 TTL should not have been cached") + } +} + +func TestCacheServfailTTL0(t *testing.T) { + c := New() + c.minpttl = minTTL + c.minnttl = minNTTL + c.failttl = 0 + c.Next = servFailBackend(0) + + req := new(dns.Msg) + req.SetQuestion("example.org.", dns.TypeA) + ctx := context.TODO() + + c.ServeDNS(ctx, &test.ResponseWriter{}, req) + if c.ncache.Len() != 0 { + t.Errorf("SERVFAIL response should not have been cached") + } +} + +func TestServeFromStaleCache(t *testing.T) { + c := New() + c.Next = ttlBackend(60) + + req := new(dns.Msg) + req.SetQuestion("cached.org.", dns.TypeA) + ctx := context.TODO() + + // Cache cached.org. with 60s TTL + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + c.staleUpTo = 1 * time.Hour + c.ServeDNS(ctx, rec, req) + if c.pcache.Len() != 1 { + t.Fatalf("Msg with > 0 TTL should have been cached") + } + + // No more backend resolutions, just from cache if available. + c.Next = plugin.HandlerFunc(func(context.Context, dns.ResponseWriter, *dns.Msg) (int, error) { + return 255, nil // Below, a 255 means we tried querying upstream. 
+ }) + + tests := []struct { + name string + futureMinutes int + expectedResult int + }{ + {"cached.org.", 30, 0}, + {"cached.org.", 60, 0}, + {"cached.org.", 70, 255}, + + {"notcached.org.", 30, 255}, + {"notcached.org.", 60, 255}, + {"notcached.org.", 70, 255}, + } + + for i, tt := range tests { + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + c.now = func() time.Time { return time.Now().Add(time.Duration(tt.futureMinutes) * time.Minute) } + r := req.Copy() + r.SetQuestion(tt.name, dns.TypeA) + if ret, _ := c.ServeDNS(ctx, rec, r); ret != tt.expectedResult { + t.Errorf("Test %d: expecting %v; got %v", i, tt.expectedResult, ret) + } + } +} + +func TestServeFromStaleCacheFetchVerify(t *testing.T) { + c := New() + c.Next = ttlBackend(120) + + req := new(dns.Msg) + req.SetQuestion("cached.org.", dns.TypeA) + ctx := context.TODO() + + // Cache cached.org. with 120s TTL + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + c.staleUpTo = 1 * time.Hour + c.verifyStale = true + c.ServeDNS(ctx, rec, req) + if c.pcache.Len() != 1 { + t.Fatalf("Msg with > 0 TTL should have been cached") + } + + tests := []struct { + name string + upstreamRCode int + upstreamTtl int + futureMinutes int + expectedRCode int + expectedTtl int + }{ + // After 1 minutes of initial TTL, we should see a cached response + {"cached.org.", dns.RcodeSuccess, 200, 1, dns.RcodeSuccess, 60}, // ttl = 120 - 60 -- not refreshed + + // After the 2 more minutes, we should see upstream responses because upstream is available + {"cached.org.", dns.RcodeSuccess, 200, 3, dns.RcodeSuccess, 200}, + + // After the TTL expired, if the server fails we should get the cached entry + {"cached.org.", dns.RcodeServerFailure, 200, 7, dns.RcodeSuccess, 0}, + + // After 1 more minutes, if the server serves nxdomain we should see them (despite being within the serve stale period) + {"cached.org.", dns.RcodeNameError, 150, 8, dns.RcodeNameError, 150}, + } + + for i, tt := range tests { + rec := 
dnstest.NewRecorder(&test.ResponseWriter{}) + c.now = func() time.Time { return time.Now().Add(time.Duration(tt.futureMinutes) * time.Minute) } + + switch tt.upstreamRCode { + case dns.RcodeSuccess: + c.Next = ttlBackend(tt.upstreamTtl) + case dns.RcodeServerFailure: + // Make upstream fail, should now rely on cache during the c.staleUpTo period + c.Next = servFailBackend(tt.upstreamTtl) + case dns.RcodeNameError: + c.Next = nxDomainBackend(tt.upstreamTtl) + default: + t.Fatal("upstream code not implemented") + } + + r := req.Copy() + r.SetQuestion(tt.name, dns.TypeA) + ret, _ := c.ServeDNS(ctx, rec, r) + if ret != tt.expectedRCode { + t.Errorf("Test %d: expected rcode=%v, got rcode=%v", i, tt.expectedRCode, ret) + continue + } + switch ret { + case dns.RcodeSuccess: + recTtl := rec.Msg.Answer[0].Header().Ttl + if tt.expectedTtl != int(recTtl) { + t.Errorf("Test %d: expected TTL=%d, got TTL=%d", i, tt.expectedTtl, recTtl) + } + case dns.RcodeNameError: + soaTtl := rec.Msg.Ns[0].Header().Ttl + if tt.expectedTtl != int(soaTtl) { + t.Errorf("Test %d: expected TTL=%d, got TTL=%d", i, tt.expectedTtl, soaTtl) + } + } + } +} + +func TestNegativeStaleMaskingPositiveCache(t *testing.T) { + c := New() + c.staleUpTo = time.Minute * 10 + c.Next = nxDomainBackend(60) + + req := new(dns.Msg) + qname := "cached.org." + req.SetQuestion(qname, dns.TypeA) + ctx := context.TODO() + + // Add an entry to Negative Cache": cached.org. 
= NXDOMAIN + expectedResult := dns.RcodeNameError + if ret, _ := c.ServeDNS(ctx, &test.ResponseWriter{}, req); ret != expectedResult { + t.Errorf("Test 0 Negative Cache Population: expecting %v; got %v", expectedResult, ret) + } + + // Confirm item was added to negative cache and not to positive cache + if c.ncache.Len() == 0 { + t.Errorf("Test 0 Negative Cache Population: item not added to negative cache") + } + if c.pcache.Len() != 0 { + t.Errorf("Test 0 Negative Cache Population: item added to positive cache") + } + + // Set the Backend to return non-cachable errors only + c.Next = plugin.HandlerFunc(func(context.Context, dns.ResponseWriter, *dns.Msg) (int, error) { + return 255, nil // Below, a 255 means we tried querying upstream. + }) + + // Confirm we get the NXDOMAIN from the negative cache, not the error form the backend + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + req = new(dns.Msg) + req.SetQuestion(qname, dns.TypeA) + expectedResult = dns.RcodeNameError + if c.ServeDNS(ctx, rec, req); rec.Rcode != expectedResult { + t.Errorf("Test 1 NXDOMAIN from Negative Cache: expecting %v; got %v", expectedResult, rec.Rcode) + } + + // Jump into the future beyond when the negative cache item would go stale + // but before the item goes rotten (exceeds serve stale time) + c.now = func() time.Time { return time.Now().Add(time.Duration(5) * time.Minute) } + + // Set Backend to return a positive NOERROR + A record response + c.Next = BackendHandler() + + // Make a query for the stale cache item + rec = dnstest.NewRecorder(&test.ResponseWriter{}) + req = new(dns.Msg) + req.SetQuestion(qname, dns.TypeA) + expectedResult = dns.RcodeNameError + if c.ServeDNS(ctx, rec, req); rec.Rcode != expectedResult { + t.Errorf("Test 2 NOERROR from Backend: expecting %v; got %v", expectedResult, rec.Rcode) + } + + // Confirm that prefetch removes the negative cache item. 
+ waitFor := 3 + for i := 1; i <= waitFor; i++ { + if c.ncache.Len() != 0 { + if i == waitFor { + t.Errorf("Test 2 NOERROR from Backend: item still exists in negative cache") + } + time.Sleep(time.Second) + continue + } + } + + // Confirm that positive cache has the item + if c.pcache.Len() != 1 { + t.Errorf("Test 2 NOERROR from Backend: item missing from positive cache") + } + + // Backend - Give error only + c.Next = plugin.HandlerFunc(func(context.Context, dns.ResponseWriter, *dns.Msg) (int, error) { + return 255, nil // Below, a 255 means we tried querying upstream. + }) + + // Query again, expect that positive cache entry is not masked by a negative cache entry + rec = dnstest.NewRecorder(&test.ResponseWriter{}) + req = new(dns.Msg) + req.SetQuestion(qname, dns.TypeA) + expectedResult = dns.RcodeSuccess + if ret, _ := c.ServeDNS(ctx, rec, req); ret != expectedResult { + t.Errorf("Test 3 NOERROR from Cache: expecting %v; got %v", expectedResult, ret) + } +} + +func BenchmarkCacheResponse(b *testing.B) { + c := New() + c.prefetch = 1 + c.Next = BackendHandler() + + ctx := context.TODO() + + reqs := make([]*dns.Msg, 5) + for i, q := range []string{"example1", "example2", "a", "b", "ddd"} { + reqs[i] = new(dns.Msg) + reqs[i].SetQuestion(q+".example.org.", dns.TypeA) + } + + b.StartTimer() + + j := 0 + for range b.N { + req := reqs[j] + c.ServeDNS(ctx, &test.ResponseWriter{}, req) + j = (j + 1) % 5 + } +} + +func BackendHandler() plugin.Handler { + return plugin.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + m := new(dns.Msg) + m.SetReply(r) + m.Response = true + m.RecursionAvailable = true + + owner := m.Question[0].Name + m.Answer = []dns.RR{test.A(owner + " 303 IN A 127.0.0.53")} + + w.WriteMsg(m) + return dns.RcodeSuccess, nil + }) +} + +func nxDomainBackend(ttl int) plugin.Handler { + return plugin.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + m := new(dns.Msg) + m.SetReply(r) 
+ m.Response, m.RecursionAvailable = true, true + + m.Ns = []dns.RR{test.SOA(fmt.Sprintf("example.org. %d IN SOA sns.dns.icann.org. noc.dns.icann.org. 2016082540 7200 3600 1209600 3600", ttl))} + + m.Rcode = dns.RcodeNameError + w.WriteMsg(m) + return dns.RcodeNameError, nil + }) +} + +func ttlBackend(ttl int) plugin.Handler { + return plugin.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + m := new(dns.Msg) + m.SetReply(r) + m.Response, m.RecursionAvailable = true, true + + m.Answer = []dns.RR{test.A(fmt.Sprintf("example.org. %d IN A 127.0.0.53", ttl))} + w.WriteMsg(m) + return dns.RcodeSuccess, nil + }) +} + +func servFailBackend(ttl int) plugin.Handler { + return plugin.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + m := new(dns.Msg) + m.SetReply(r) + m.Response, m.RecursionAvailable = true, true + + m.Ns = []dns.RR{test.SOA(fmt.Sprintf("example.org. %d IN SOA sns.dns.icann.org. noc.dns.icann.org. 2016082540 7200 3600 1209600 3600", ttl))} + + m.Rcode = dns.RcodeServerFailure + w.WriteMsg(m) + return dns.RcodeServerFailure, nil + }) +} + +func TestComputeTTL(t *testing.T) { + tests := []struct { + msgTTL time.Duration + minTTL time.Duration + maxTTL time.Duration + expectedTTL time.Duration + }{ + {1800 * time.Second, 300 * time.Second, 3600 * time.Second, 1800 * time.Second}, + {299 * time.Second, 300 * time.Second, 3600 * time.Second, 300 * time.Second}, + {299 * time.Second, 0 * time.Second, 3600 * time.Second, 299 * time.Second}, + {3601 * time.Second, 300 * time.Second, 3600 * time.Second, 3600 * time.Second}, + } + for i, test := range tests { + ttl := computeTTL(test.msgTTL, test.minTTL, test.maxTTL) + if ttl != test.expectedTTL { + t.Errorf("Test %v: Expected ttl %v but found: %v", i, test.expectedTTL, ttl) + } + } +} + +func TestCacheWildcardMetadata(t *testing.T) { + c := New() + qname := "foo.bar.example.org." + wildcard := "*.bar.example.org." 
+ c.Next = wildcardMetadataBackend(qname, wildcard) + + req := new(dns.Msg) + req.SetQuestion(qname, dns.TypeA) + state := request.Request{W: &test.ResponseWriter{}, Req: req} + + // 1. Test writing wildcard metadata retrieved from backend to the cache + + ctx := metadata.ContextWithMetadata(context.TODO()) + w := dnstest.NewRecorder(&test.ResponseWriter{}) + c.ServeDNS(ctx, w, req) + if c.pcache.Len() != 1 { + t.Errorf("Msg should have been cached") + } + _, k := key(qname, w.Msg, response.NoError, state.Do(), state.Req.CheckingDisabled) + i, _ := c.pcache.Get(k) + if i.(*item).wildcard != wildcard { + t.Errorf("expected wildcard response to enter cache with cache item's wildcard = %q, got %q", wildcard, i.(*item).wildcard) + } + + // 2. Test retrieving the cached item from cache and writing its wildcard value to metadata + + // reset context and response writer + ctx = metadata.ContextWithMetadata(context.TODO()) + w = dnstest.NewRecorder(&test.ResponseWriter{}) + + c.ServeDNS(ctx, w, req) + f := metadata.ValueFunc(ctx, "zone/wildcard") + if f == nil { + t.Fatal("expected metadata func for wildcard response retrieved from cache, got nil") + } + if f() != wildcard { + t.Errorf("after retrieving wildcard item from cache, expected \"zone/wildcard\" metadata value to be %q, got %q", wildcard, i.(*item).wildcard) + } +} + +func TestCacheKeepTTL(t *testing.T) { + defaultTtl := 60 + + c := New() + c.Next = ttlBackend(defaultTtl) + + req := new(dns.Msg) + req.SetQuestion("cached.org.", dns.TypeA) + ctx := context.TODO() + + // Cache cached.org. 
with 60s TTL + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + c.keepttl = true + c.ServeDNS(ctx, rec, req) + + tests := []struct { + name string + futureSeconds int + }{ + {"cached.org.", 0}, + {"cached.org.", 30}, + {"uncached.org.", 60}, + } + + for i, tt := range tests { + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + c.now = func() time.Time { return time.Now().Add(time.Duration(tt.futureSeconds) * time.Second) } + r := req.Copy() + r.SetQuestion(tt.name, dns.TypeA) + c.ServeDNS(ctx, rec, r) + + recTtl := rec.Msg.Answer[0].Header().Ttl + if defaultTtl != int(recTtl) { + t.Errorf("Test %d: expecting TTL=%d, got TTL=%d", i, defaultTtl, recTtl) + } + } +} + +// TestCacheSeparation verifies whether the cache maintains separation for specific DNS query types and options. +func TestCacheSeparation(t *testing.T) { + now, _ := time.Parse(time.UnixDate, "Fri Apr 21 10:51:21 BST 2017") + utc := now.UTC() + + testCases := []struct { + name string + initial test.Case + query test.Case + expectCached bool // if a cache entry should be found before inserting + }{ + { + name: "query type should be unique", + initial: test.Case{ + Qname: "example.org.", + Qtype: dns.TypeA, + }, + query: test.Case{ + Qname: "example.org.", + Qtype: dns.TypeAAAA, + }, + }, + { + name: "DO bit should be unique", + initial: test.Case{ + Qname: "example.org.", + Qtype: dns.TypeA, + }, + query: test.Case{ + Qname: "example.org.", + Qtype: dns.TypeA, + Do: true, + }, + }, + { + name: "CD bit should be unique", + initial: test.Case{ + Qname: "example.org.", + Qtype: dns.TypeA, + }, + query: test.Case{ + Qname: "example.org.", + Qtype: dns.TypeA, + CheckingDisabled: true, + }, + }, + { + name: "CD bit and DO bit should be unique", + initial: test.Case{ + Qname: "example.org.", + Qtype: dns.TypeA, + }, + query: test.Case{ + Qname: "example.org.", + Qtype: dns.TypeA, + CheckingDisabled: true, + Do: true, + }, + }, + { + name: "CD bit, DO bit, and query type should be unique", + initial: 
test.Case{ + Qname: "example.org.", + Qtype: dns.TypeA, + }, + query: test.Case{ + Qname: "example.org.", + Qtype: dns.TypeMX, + CheckingDisabled: true, + Do: true, + }, + }, + { + name: "authoritative answer bit should NOT be unique", + initial: test.Case{ + Qname: "example.org.", + Qtype: dns.TypeA, + }, + query: test.Case{ + Qname: "example.org.", + Qtype: dns.TypeA, + Authoritative: true, + }, + expectCached: true, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := New() + crr := &ResponseWriter{ResponseWriter: nil, Cache: c} + + // Insert initial cache entry + m := tc.initial.Msg() + m = cacheMsg(m, tc.initial) + state := request.Request{W: &test.ResponseWriter{}, Req: m} + + mt, _ := response.Typify(m, utc) + valid, k := key(state.Name(), m, mt, state.Do(), state.Req.CheckingDisabled) + + if valid { + // Insert cache entry + crr.set(m, k, mt, c.pttl) + } + + // Attempt to retrieve cache entry + m = tc.query.Msg() + m = cacheMsg(m, tc.query) + state = request.Request{W: &test.ResponseWriter{}, Req: m} + + item := c.getIgnoreTTL(time.Now().UTC(), state, "dns://:53") + found := item != nil + + if !tc.expectCached && found { + t.Fatal("Found cache message should that should not exist prior to inserting") + } + if tc.expectCached && !found { + t.Fatal("Did not find cache message that should exist prior to inserting") + } + }) + } +} + +// wildcardMetadataBackend mocks a backend that responds with a response for qname synthesized by wildcard +// and sets the zone/wildcard metadata value +func wildcardMetadataBackend(qname, wildcard string) plugin.Handler { + return plugin.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + m := new(dns.Msg) + m.SetReply(r) + m.Response, m.RecursionAvailable = true, true + m.Answer = []dns.RR{test.A(qname + " 300 IN A 127.0.0.1")} + metadata.SetValueFunc(ctx, "zone/wildcard", func() string { + return wildcard + }) + w.WriteMsg(m) + + return dns.RcodeSuccess, 
nil + }) +} diff --git a/dnssec.go b/dnssec.go new file mode 100644 index 0000000..da7e1e9 --- /dev/null +++ b/dnssec.go @@ -0,0 +1,24 @@ +package cache + +import "github.com/miekg/dns" + +// filterRRSlice filters out OPT RRs, and sets all RR TTLs to ttl. +// If dup is true the RRs in rrs are _copied_ before adjusting their +// TTL and the slice of copied RRs is returned. +func filterRRSlice(rrs []dns.RR, ttl uint32, dup bool) []dns.RR { + j := 0 + rs := make([]dns.RR, len(rrs)) + for _, r := range rrs { + if r.Header().Rrtype == dns.TypeOPT { + continue + } + if dup { + rs[j] = dns.Copy(r) + } else { + rs[j] = r + } + rs[j].Header().Ttl = ttl + j++ + } + return rs[:j] +} diff --git a/dnssec_test.go b/dnssec_test.go new file mode 100644 index 0000000..b73d52c --- /dev/null +++ b/dnssec_test.go @@ -0,0 +1,126 @@ +package cache + +import ( + "context" + "testing" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +func TestResponseWithDNSSEC(t *testing.T) { + // We do 2 queries, one where we want non-dnssec and one with dnssec and check the responses in each of them + var tcs = []test.Case{ + { + Qname: "invent.example.org.", Qtype: dns.TypeA, + Answer: []dns.RR{ + test.CNAME("invent.example.org. 1781 IN CNAME leptone.example.org."), + test.A("leptone.example.org. 1781 IN A 195.201.182.103"), + }, + }, + { + Qname: "invent.example.org.", Qtype: dns.TypeA, + Do: true, + AuthenticatedData: true, + Answer: []dns.RR{ + test.CNAME("invent.example.org. 1781 IN CNAME leptone.example.org."), + test.RRSIG("invent.example.org. 1781 IN RRSIG CNAME 8 3 1800 20201012085750 20200912082613 57411 example.org. ijSv5FmsNjFviBcOFwQgqjt073lttxTTNqkno6oMa3DD3kC+"), + test.A("leptone.example.org. 1781 IN A 195.201.182.103"), + test.RRSIG("leptone.example.org. 1781 IN RRSIG A 8 3 1800 20201012093630 20200912083827 57411 example.org. 
eLuSOkLAzm/WIOpaZD3/4TfvKP1HAFzjkis9LIJSRVpQt307dm9WY9"), + }, + }, + } + + c := New() + c.Next = dnssecHandler() + + for i, tc := range tcs { + m := tc.Msg() + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + c.ServeDNS(context.TODO(), rec, m) + if tc.AuthenticatedData != rec.Msg.AuthenticatedData { + t.Errorf("Test %d, expected AuthenticatedData=%v", i, tc.AuthenticatedData) + } + if err := test.Section(tc, test.Answer, rec.Msg.Answer); err != nil { + t.Errorf("Test %d, expected no error, got %s", i, err) + } + } + + // now do the reverse + c = New() + c.Next = dnssecHandler() + + for i, tc := range []test.Case{tcs[1], tcs[0]} { + m := tc.Msg() + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + c.ServeDNS(context.TODO(), rec, m) + if err := test.Section(tc, test.Answer, rec.Msg.Answer); err != nil { + t.Errorf("Test %d, expected no error, got %s", i, err) + } + } +} + +func dnssecHandler() plugin.Handler { + return plugin.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + m := new(dns.Msg) + m.SetQuestion("example.org.", dns.TypeA) + state := request.Request{W: &test.ResponseWriter{}, Req: r} + + m.AuthenticatedData = true + // If query has the DO bit, then send DNSSEC responses (RRSIGs) + if state.Do() { + m.Answer = make([]dns.RR, 4) + m.Answer[0] = test.CNAME("invent.example.org. 1781 IN CNAME leptone.example.org.") + m.Answer[1] = test.RRSIG("invent.example.org. 1781 IN RRSIG CNAME 8 3 1800 20201012085750 20200912082613 57411 example.org. ijSv5FmsNjFviBcOFwQgqjt073lttxTTNqkno6oMa3DD3kC+") + m.Answer[2] = test.A("leptone.example.org. 1781 IN A 195.201.182.103") + m.Answer[3] = test.RRSIG("leptone.example.org. 1781 IN RRSIG A 8 3 1800 20201012093630 20200912083827 57411 example.org. eLuSOkLAzm/WIOpaZD3/4TfvKP1HAFzjkis9LIJSRVpQt307dm9WY9") + } else { + m.Answer = make([]dns.RR, 2) + m.Answer[0] = test.CNAME("invent.example.org. 1781 IN CNAME leptone.example.org.") + m.Answer[1] = test.A("leptone.example.org. 
1781 IN A 195.201.182.103") + } + w.WriteMsg(m) + return dns.RcodeSuccess, nil + }) +} + +func TestFilterRRSlice(t *testing.T) { + rrs := []dns.RR{ + test.CNAME("invent.example.org. 1781 IN CNAME leptone.example.org."), + test.RRSIG("invent.example.org. 1781 IN RRSIG CNAME 8 3 1800 20201012085750 20200912082613 57411 example.org. ijSv5FmsNjFviBcOFwQgqjt073lttxTTNqkno6oMa3DD3kC+"), + test.A("leptone.example.org. 1781 IN A 195.201.182.103"), + test.RRSIG("leptone.example.org. 1781 IN RRSIG A 8 3 1800 20201012093630 20200912083827 57411 example.org. eLuSOkLAzm/WIOpaZD3/4TfvKP1HAFzjkis9LIJSRVpQt307dm9WY9"), + } + + filter1 := filterRRSlice(rrs, 0, false) + if len(filter1) != 4 { + t.Errorf("Expected 4 RRs after filtering, got %d", len(filter1)) + } + rrsig := 0 + for _, f := range filter1 { + if f.Header().Rrtype == dns.TypeRRSIG { + rrsig++ + } + } + if rrsig != 2 { + t.Errorf("Expected 2 RRSIGs after filtering, got %d", rrsig) + } + + filter2 := filterRRSlice(rrs, 0, false) + if len(filter2) != 4 { + t.Errorf("Expected 4 RRs after filtering, got %d", len(filter2)) + } + rrsig = 0 + for _, f := range filter2 { + if f.Header().Rrtype == dns.TypeRRSIG { + rrsig++ + } + } + if rrsig != 2 { + t.Errorf("Expected 2 RRSIGs after filtering, got %d", rrsig) + } +} diff --git a/error_test.go b/error_test.go new file mode 100644 index 0000000..cd18fda --- /dev/null +++ b/error_test.go @@ -0,0 +1,38 @@ +package cache + +import ( + "context" + "testing" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +func TestFormErr(t *testing.T) { + c := New() + c.Next = formErrHandler() + + req := new(dns.Msg) + req.SetQuestion("example.org.", dns.TypeA) + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + + c.ServeDNS(context.TODO(), rec, req) + + if c.pcache.Len() != 0 { + t.Errorf("Cached %s, while reply had %d", "example.org.", rec.Msg.Rcode) + } +} + +// formErrHandler is 
a fake plugin implementation which returns a FORMERR for a reply. +func formErrHandler() plugin.Handler { + return plugin.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + m := new(dns.Msg) + m.SetQuestion("example.net.", dns.TypeA) + m.Rcode = dns.RcodeFormatError + w.WriteMsg(m) + return dns.RcodeSuccess, nil + }) +} diff --git a/freq/freq.go b/freq/freq.go new file mode 100644 index 0000000..f545f22 --- /dev/null +++ b/freq/freq.go @@ -0,0 +1,55 @@ +// Package freq keeps track of last X seen events. The events themselves are not stored +// here. So the Freq type should be added next to the thing it is tracking. +package freq + +import ( + "sync" + "time" +) + +// Freq tracks the frequencies of things. +type Freq struct { + // Last time we saw a query for this element. + last time.Time + // Number of this in the last time slice. + hits int + + sync.RWMutex +} + +// New returns a new initialized Freq. +func New(t time.Time) *Freq { + return &Freq{last: t, hits: 0} +} + +// Update updates the number of hits. Last time seen will be set to now. +// If the last time we've seen this entity is within now - d, we increment hits, otherwise +// we reset hits to 1. It returns the number of hits. +func (f *Freq) Update(d time.Duration, now time.Time) int { + earliest := now.Add(-1 * d) + f.Lock() + defer f.Unlock() + if f.last.Before(earliest) { + f.last = now + f.hits = 1 + return f.hits + } + f.last = now + f.hits++ + return f.hits +} + +// Hits returns the number of hits that we have seen, according to the updates we have done to f. +func (f *Freq) Hits() int { + f.RLock() + defer f.RUnlock() + return f.hits +} + +// Reset resets f to time t and hits to hits. 
+func (f *Freq) Reset(t time.Time, hits int) { + f.Lock() + defer f.Unlock() + f.last = t + f.hits = hits +} diff --git a/freq/freq_test.go b/freq/freq_test.go new file mode 100644 index 0000000..fc6042c --- /dev/null +++ b/freq/freq_test.go @@ -0,0 +1,37 @@ +package freq + +import ( + "testing" + "time" +) + +func TestFreqUpdate(t *testing.T) { + now := time.Now().UTC() + f := New(now) + window := 1 * time.Minute + + f.Update(window, time.Now().UTC()) + f.Update(window, time.Now().UTC()) + f.Update(window, time.Now().UTC()) + hitsCheck(t, f, 3) + + f.Reset(now, 0) + history := time.Now().UTC().Add(-3 * time.Minute) + f.Update(window, history) + hitsCheck(t, f, 1) +} + +func TestReset(t *testing.T) { + f := New(time.Now().UTC()) + f.Update(1*time.Minute, time.Now().UTC()) + hitsCheck(t, f, 1) + f.Reset(time.Now().UTC(), 0) + hitsCheck(t, f, 0) +} + +func hitsCheck(t *testing.T, f *Freq, expected int) { + t.Helper() + if x := f.Hits(); x != expected { + t.Fatalf("Expected hits to be %d, got %d", expected, x) + } +} diff --git a/fuzz.go b/fuzz.go new file mode 100644 index 0000000..43f4d26 --- /dev/null +++ b/fuzz.go @@ -0,0 +1,12 @@ +//go:build gofuzz + +package cache + +import ( + "github.com/coredns/coredns/plugin/pkg/fuzz" +) + +// Fuzz fuzzes cache. 
+func Fuzz(data []byte) int { + return fuzz.Do(New(), data) +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..46ee3a3 --- /dev/null +++ b/go.mod @@ -0,0 +1,41 @@ +module git.ohea.xyz/restitux/coredns-host-specific-cache + +go 1.24.4 + +require ( + github.com/coredns/caddy v1.1.2-0.20241029205200-8de985351a98 + github.com/coredns/coredns v1.12.2 + github.com/miekg/dns v1.1.66 + github.com/prometheus/client_golang v1.22.0 +) + +require ( + github.com/apparentlymart/go-cidr v1.1.0 // indirect + github.com/beorn7/perks v1.0.1 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568 // indirect + github.com/go-task/slim-sprig/v3 v3.0.0 // indirect + github.com/golang/protobuf v1.5.4 // indirect + github.com/google/pprof v0.0.0-20241210010833-40e02aabc2ad // indirect + github.com/grpc-ecosystem/grpc-opentracing v0.0.0-20180507213350-8e809c8a8645 // indirect + github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect + github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect + github.com/onsi/ginkgo/v2 v2.22.1 // indirect + github.com/opentracing/opentracing-go v1.2.0 // indirect + github.com/prometheus/client_model v0.6.2 // indirect + github.com/prometheus/common v0.64.0 // indirect + github.com/prometheus/procfs v0.15.1 // indirect + github.com/quic-go/quic-go v0.52.0 // indirect + go.uber.org/automaxprocs v1.6.0 // indirect + go.uber.org/mock v0.5.0 // indirect + golang.org/x/crypto v0.38.0 // indirect + golang.org/x/mod v0.24.0 // indirect + golang.org/x/net v0.40.0 // indirect + golang.org/x/sync v0.14.0 // indirect + golang.org/x/sys v0.33.0 // indirect + golang.org/x/text v0.25.0 // indirect + golang.org/x/tools v0.33.0 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20250512202823-5a2f75b736a9 // indirect + google.golang.org/grpc v1.72.2 // indirect + google.golang.org/protobuf v1.36.6 // indirect +) diff --git a/go.sum b/go.sum new 
file mode 100644 index 0000000..5e3d95f --- /dev/null +++ b/go.sum @@ -0,0 +1,106 @@ +github.com/apparentlymart/go-cidr v1.1.0 h1:2mAhrMoF+nhXqxTzSZMUzDHkLjmIHC+Zzn4tdgBZjnU= +github.com/apparentlymart/go-cidr v1.1.0/go.mod h1:EBcsNrHc3zQeuaeCeCtQruQm+n9/YjEn/vI25Lg7Gwc= +github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= +github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/coredns/caddy v1.1.2-0.20241029205200-8de985351a98 h1:c+Epklw9xk6BZ1OFBPWLA2PcL8QalKvl3if8CP9x8uw= +github.com/coredns/caddy v1.1.2-0.20241029205200-8de985351a98/go.mod h1:A6ntJQlAWuQfFlsd9hvigKbo2WS0VUs2l1e2F+BawD4= +github.com/coredns/coredns v1.12.2 h1:G4oDfi340zlVsriZ8nYiUemiQIew7nqOO+QPvPxIA4Y= +github.com/coredns/coredns v1.12.2/go.mod h1:GFz31oVOfCyMArFoypfu1SoaFoNkbdh6lDxtF1B6vfU= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568 h1:BHsljHzVlRcyQhjrss6TZTdY2VfCqZPbv5k3iBFa2ZQ= +github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568/go.mod h1:xEzjJPgXI435gkrCt3MPfRiAkVrwSbHsst4LCFVfpJc= +github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= +github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI= 
+github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= +github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/pprof v0.0.0-20241210010833-40e02aabc2ad h1:a6HEuzUHeKH6hwfN/ZoQgRgVIWFJljSWa/zetS2WTvg= +github.com/google/pprof v0.0.0-20241210010833-40e02aabc2ad/go.mod h1:vavhavw2zAxS5dIdcRluK6cSGGPlZynqzFM8NdvU144= +github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/grpc-ecosystem/grpc-opentracing v0.0.0-20180507213350-8e809c8a8645 h1:MJG/KsmcqMwFAkh8mTnAwhyKoB+sTAnY4CACC110tbU= +github.com/grpc-ecosystem/grpc-opentracing v0.0.0-20180507213350-8e809c8a8645/go.mod h1:6iZfnjpejD4L/4DwD7NryNaJyCQdzwWwH2MWhCA90Kw= +github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= +github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= +github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= +github.com/matttproud/golang_protobuf_extensions v1.0.4 h1:mmDVorXM7PCGKw94cs5zkfA9PSy5pEvNWRP0ET0TIVo= +github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4= +github.com/miekg/dns v1.1.66 h1:FeZXOS3VCVsKnEAd+wBkjMC3D2K+ww66Cq3VnCINuJE= +github.com/miekg/dns v1.1.66/go.mod 
h1:jGFzBsSNbJw6z1HYut1RKBKHA9PBdxeHrZG8J+gC2WE= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= +github.com/onsi/ginkgo/v2 v2.22.1 h1:QW7tbJAUDyVDVOM5dFa7qaybo+CRfR7bemlQUN6Z8aM= +github.com/onsi/ginkgo/v2 v2.22.1/go.mod h1:S6aTpoRsSq2cZOd+pssHAlKW/Q/jZt6cPrPlnj4a1xM= +github.com/onsi/gomega v1.36.2 h1:koNYke6TVk6ZmnyHrCXba/T/MoLBXFjeC1PtvYgw0A8= +github.com/onsi/gomega v1.36.2/go.mod h1:DdwyADRjrc825LhMEkD76cHR5+pUnjhUN8GlHlRPHzY= +github.com/opentracing/opentracing-go v1.2.0 h1:uEJPy/1a5RIPAJ0Ov+OIO8OxWu77jEv+1B0VhjKrZUs= +github.com/opentracing/opentracing-go v1.2.0/go.mod h1:GxEUsuufX4nBwe+T+Wl9TAgYrxe9dPLANfrWvHYVTgc= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prashantv/gostub v1.1.0 h1:BTyx3RfQjRHnUWaGF9oQos79AlQ5k8WNktv7VGvVH4g= +github.com/prashantv/gostub v1.1.0/go.mod h1:A5zLQHz7ieHGG7is6LLXLz7I8+3LZzsrV0P1IAHhP5U= +github.com/prometheus/client_golang v1.22.0 h1:rb93p9lokFEsctTys46VnV1kLCDpVZ0a/Y92Vm0Zc6Q= +github.com/prometheus/client_golang v1.22.0/go.mod h1:R7ljNsLXhuQXYZYtw6GAE9AZg8Y7vEW5scdCXrWRXC0= +github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk= +github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE= +github.com/prometheus/common v0.64.0 h1:pdZeA+g617P7oGv1CzdTzyeShxAGrTBsolKNOLQPGO4= +github.com/prometheus/common v0.64.0/go.mod h1:0gZns+BLRQ3V6NdaerOhMbwwRbNh9hkGINtQAsP5GS8= +github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= +github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= +github.com/quic-go/quic-go v0.52.0 h1:/SlHrCRElyaU6MaEPKqKr9z83sBg2v4FLLvWM+Z47pA= 
+github.com/quic-go/quic-go v0.52.0/go.mod h1:MFlGGpcpJqRAfmYi6NC2cptDPSxRWTOGNuP4wqrWmzQ= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= +go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= +go.opentelemetry.io/otel v1.35.0 h1:xKWKPxrxB6OtMCbmMY021CqC45J+3Onta9MqjhnusiQ= +go.opentelemetry.io/otel v1.35.0/go.mod h1:UEqy8Zp11hpkUrL73gSlELM0DupHoiq72dR+Zqel/+Y= +go.opentelemetry.io/otel/metric v1.35.0 h1:0znxYu2SNyuMSQT4Y9WDWej0VpcsxkuklLa4/siN90M= +go.opentelemetry.io/otel/metric v1.35.0/go.mod h1:nKVFgxBZ2fReX6IlyW28MgZojkoAkJGaE8CpgeAU3oE= +go.opentelemetry.io/otel/sdk v1.34.0 h1:95zS4k/2GOy069d321O8jWgYsW3MzVV+KuSPKp7Wr1A= +go.opentelemetry.io/otel/sdk v1.34.0/go.mod h1:0e/pNiaMAqaykJGKbi+tSjWfNNHMTxoC9qANsCzbyxU= +go.opentelemetry.io/otel/sdk/metric v1.34.0 h1:5CeK9ujjbFVL5c1PhLuStg1wxA7vQv7ce1EK0Gyvahk= +go.opentelemetry.io/otel/sdk/metric v1.34.0/go.mod h1:jQ/r8Ze28zRKoNRdkjCZxfs6YvBTG1+YIqyFVFYec5w= +go.opentelemetry.io/otel/trace v1.35.0 h1:dPpEfJu1sDIqruz7BHFG3c7528f6ddfSWfFDVt/xgMs= +go.opentelemetry.io/otel/trace v1.35.0/go.mod h1:WUk7DtFp1Aw2MkvqGdwiXYDZZNvA/1J8o6xRXLrIkyc= +go.uber.org/automaxprocs v1.6.0 h1:O3y2/QNTOdbF+e/dpXNNW7Rx2hZ4sTIPyybbxyNqTUs= +go.uber.org/automaxprocs v1.6.0/go.mod h1:ifeIMSnPZuznNm6jmdzmU3/bfk01Fe2fotchwEFJ8r8= +go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU= +go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM= +golang.org/x/crypto v0.38.0 h1:jt+WWG8IZlBnVbomuhg2Mdq0+BBQaHbtqHEFEigjUV8= +golang.org/x/crypto v0.38.0/go.mod 
h1:MvrbAqul58NNYPKnOra203SB9vpuZW0e+RRZV+Ggqjw= +golang.org/x/mod v0.24.0 h1:ZfthKaKaT4NrhGVZHO1/WDTwGES4De8KtWO0SIbNJMU= +golang.org/x/mod v0.24.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww= +golang.org/x/net v0.40.0 h1:79Xs7wF06Gbdcg4kdCCIQArK11Z1hr5POQ6+fIYHNuY= +golang.org/x/net v0.40.0/go.mod h1:y0hY0exeL2Pku80/zKK7tpntoX23cqL3Oa6njdgRtds= +golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.14.0 h1:woo0S4Yywslg6hp4eUFjTVOyKt0RookbpAHG4c1HmhQ= +golang.org/x/sync v0.14.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= +golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4= +golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA= +golang.org/x/tools v0.33.0 h1:4qz2S3zmRxbGIhDIAgjxvFutSvH5EfnsYrRBj0UI0bc= +golang.org/x/tools v0.33.0/go.mod h1:CIJMaWEY88juyUfo7UbgPqbC8rU2OqfAV1h2Qp0oMYI= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250512202823-5a2f75b736a9 h1:IkAfh6J/yllPtpYFU0zZN1hUPYdT0ogkBT/9hMxHjvg= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250512202823-5a2f75b736a9/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A= +google.golang.org/grpc v1.72.2 h1:TdbGzwb82ty4OusHWepvFWGLgIbNo1/SUynEN0ssqv8= +google.golang.org/grpc v1.72.2/go.mod h1:wH5Aktxcg25y1I3w7H69nHfXdOG3UiadoBtjh3izSDM= +google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY= +google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/handler.go b/handler.go new file mode 100644 index 0000000..0a5ce52 --- /dev/null +++ b/handler.go @@ -0,0 +1,158 @@ +package cache + 
+import ( + "context" + "math" + "time" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/metadata" + "github.com/coredns/coredns/plugin/metrics" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +// ServeDNS implements the plugin.Handler interface. +func (c *Cache) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + rc := r.Copy() // We potentially modify r, to prevent other plugins from seeing this (r is a pointer), copy r into rc. + state := request.Request{W: w, Req: rc} + do := state.Do() + cd := r.CheckingDisabled + ad := r.AuthenticatedData + + zone := plugin.Zones(c.Zones).Matches(state.Name()) + if zone == "" { + return plugin.NextOrFailure(c.Name(), c.Next, ctx, w, rc) + } + + now := c.now().UTC() + server := metrics.WithServer(ctx) + + // On cache refresh, we will just use the DO bit from the incoming query for the refresh since we key our cache + // with the query DO bit. That means two separate cache items for the query DO bit true or false. In the situation + // in which upstream doesn't support DNSSEC, the two cache items will effectively be the same. Regardless, any + // DNSSEC RRs in the response are written to cache with the response. + + i := c.getIgnoreTTL(now, state, server) + if i == nil { + crr := &ResponseWriter{ + ResponseWriter: w, Cache: c, state: state, server: server, do: do, ad: ad, cd: cd, + nexcept: c.nexcept, pexcept: c.pexcept, wildcardFunc: wildcardFunc(ctx), + } + return c.doRefresh(ctx, state, crr) + } + ttl := i.ttl(now) + if ttl < 0 { + // serve stale behavior + if c.verifyStale { + crr := &ResponseWriter{ResponseWriter: w, Cache: c, state: state, server: server, do: do, cd: cd} + cw := newVerifyStaleResponseWriter(crr) + ret, err := c.doRefresh(ctx, state, cw) + if cw.refreshed { + return ret, err + } + } + + // Adjust the time to get a 0 TTL in the reply built from a stale item. 
+ now = now.Add(time.Duration(ttl) * time.Second) + if !c.verifyStale { + cw := newPrefetchResponseWriter(server, state, c) + go c.doPrefetch(ctx, state, cw, i, now) + } + servedStale.WithLabelValues(server, c.zonesMetricLabel, c.viewMetricLabel).Inc() + } else if c.shouldPrefetch(i, now) { + cw := newPrefetchResponseWriter(server, state, c) + go c.doPrefetch(ctx, state, cw, i, now) + } + + if i.wildcard != "" { + // Set wildcard source record name to metadata + metadata.SetValueFunc(ctx, "zone/wildcard", func() string { + return i.wildcard + }) + } + + if c.keepttl { + // If keepttl is enabled we fake the current time to the stored + // one so that we always get the original TTL + now = i.stored + } + resp := i.toMsg(r, now, do, ad) + w.WriteMsg(resp) + return dns.RcodeSuccess, nil +} + +func wildcardFunc(ctx context.Context) func() string { + return func() string { + // Get wildcard source record name from metadata + if f := metadata.ValueFunc(ctx, "zone/wildcard"); f != nil { + return f() + } + return "" + } +} + +func (c *Cache) doPrefetch(ctx context.Context, state request.Request, cw *ResponseWriter, i *item, now time.Time) { + cachePrefetches.WithLabelValues(cw.server, c.zonesMetricLabel, c.viewMetricLabel).Inc() + c.doRefresh(ctx, state, cw) + + // When prefetching we loose the item i, and with it the frequency + // that we've gathered sofar. See we copy the frequencies info back + // into the new item that was stored in the cache. 
+ if i1 := c.exists(state); i1 != nil { + i1.Reset(now, i.Hits()) + } +} + +func (c *Cache) doRefresh(ctx context.Context, state request.Request, cw dns.ResponseWriter) (int, error) { + return plugin.NextOrFailure(c.Name(), c.Next, ctx, cw, state.Req) +} + +func (c *Cache) shouldPrefetch(i *item, now time.Time) bool { + if c.prefetch <= 0 { + return false + } + i.Update(c.duration, now) + threshold := int(math.Ceil(float64(c.percentage) / 100 * float64(i.origTTL))) + return i.Hits() >= c.prefetch && i.ttl(now) <= threshold +} + +// Name implements the Handler interface. +func (c *Cache) Name() string { return "cache" } + +// getIgnoreTTL unconditionally returns an item if it exists in the cache. +func (c *Cache) getIgnoreTTL(now time.Time, state request.Request, server string) *item { + k := hash(state.Name(), state.IP(), state.QType(), state.Do(), state.Req.CheckingDisabled) + cacheRequests.WithLabelValues(server, c.zonesMetricLabel, c.viewMetricLabel).Inc() + + if i, ok := c.ncache.Get(k); ok { + itm := i.(*item) + ttl := itm.ttl(now) + if itm.matches(state) && (ttl > 0 || (c.staleUpTo > 0 && -ttl < int(c.staleUpTo.Seconds()))) { + cacheHits.WithLabelValues(server, Denial, c.zonesMetricLabel, c.viewMetricLabel).Inc() + return i.(*item) + } + } + if i, ok := c.pcache.Get(k); ok { + itm := i.(*item) + ttl := itm.ttl(now) + if itm.matches(state) && (ttl > 0 || (c.staleUpTo > 0 && -ttl < int(c.staleUpTo.Seconds()))) { + cacheHits.WithLabelValues(server, Success, c.zonesMetricLabel, c.viewMetricLabel).Inc() + return i.(*item) + } + } + cacheMisses.WithLabelValues(server, c.zonesMetricLabel, c.viewMetricLabel).Inc() + return nil +} + +func (c *Cache) exists(state request.Request) *item { + k := hash(state.Name(), state.IP(), state.QType(), state.Do(), state.Req.CheckingDisabled) + if i, ok := c.ncache.Get(k); ok { + return i.(*item) + } + if i, ok := c.pcache.Get(k); ok { + return i.(*item) + } + return nil +} diff --git a/item.go b/item.go new file mode 100644 index 
0000000..c5aeccd --- /dev/null +++ b/item.go @@ -0,0 +1,107 @@ +package cache + +import ( + "strings" + "time" + + "github.com/coredns/coredns/plugin/cache/freq" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +type item struct { + Name string + QType uint16 + Rcode int + AuthenticatedData bool + RecursionAvailable bool + Answer []dns.RR + Ns []dns.RR + Extra []dns.RR + wildcard string + + origTTL uint32 + stored time.Time + + *freq.Freq +} + +func newItem(m *dns.Msg, now time.Time, d time.Duration) *item { + i := new(item) + if len(m.Question) != 0 { + i.Name = m.Question[0].Name + i.QType = m.Question[0].Qtype + } + i.Rcode = m.Rcode + i.AuthenticatedData = m.AuthenticatedData + i.RecursionAvailable = m.RecursionAvailable + i.Answer = m.Answer + i.Ns = m.Ns + i.Extra = make([]dns.RR, len(m.Extra)) + // Don't copy OPT records as these are hop-by-hop. + j := 0 + for _, e := range m.Extra { + if e.Header().Rrtype == dns.TypeOPT { + continue + } + i.Extra[j] = e + j++ + } + i.Extra = i.Extra[:j] + + i.origTTL = uint32(d.Seconds()) + i.stored = now.UTC() + + i.Freq = new(freq.Freq) + + return i +} + +// toMsg turns i into a message, it tailors the reply to m. +// The Authoritative bit should be set to 0, but some client stub resolver implementations, most notably, +// on some legacy systems(e.g. ubuntu 14.04 with glib version 2.20), low-level glibc function `getaddrinfo` +// useb by Python/Ruby/etc.. will discard answers that do not have this bit set. +// So we're forced to always set this to 1; regardless if the answer came from the cache or not. +// On newer systems(e.g. ubuntu 16.04 with glib version 2.23), this issue is resolved. +// So we may set this bit back to 0 in the future ? +func (i *item) toMsg(m *dns.Msg, now time.Time, do bool, ad bool) *dns.Msg { + m1 := new(dns.Msg) + m1.SetReply(m) + + // Set this to true as some DNS clients discard the *entire* packet when it's non-authoritative. 
+ // This is probably not according to spec, but the bit itself is not super useful as this point, so + // just set it to true. + m1.Authoritative = true + m1.AuthenticatedData = i.AuthenticatedData + if !do && !ad { + // When DNSSEC was not wanted, it can't be authenticated data. + // However, retain the AD bit if the requester set the AD bit, per RFC6840 5.7-5.8 + m1.AuthenticatedData = false + } + m1.RecursionAvailable = i.RecursionAvailable + m1.Rcode = i.Rcode + + m1.Answer = make([]dns.RR, len(i.Answer)) + m1.Ns = make([]dns.RR, len(i.Ns)) + m1.Extra = make([]dns.RR, len(i.Extra)) + + ttl := uint32(i.ttl(now)) + m1.Answer = filterRRSlice(i.Answer, ttl, true) + m1.Ns = filterRRSlice(i.Ns, ttl, true) + m1.Extra = filterRRSlice(i.Extra, ttl, true) + + return m1 +} + +func (i *item) ttl(now time.Time) int { + ttl := int(i.origTTL) - int(now.UTC().Sub(i.stored).Seconds()) + return ttl +} + +func (i *item) matches(state request.Request) bool { + if state.QType() == i.QType && strings.EqualFold(state.QName(), i.Name) { + return true + } + return false +} diff --git a/log_test.go b/log_test.go new file mode 100644 index 0000000..220b206 --- /dev/null +++ b/log_test.go @@ -0,0 +1,5 @@ +package cache + +import clog "github.com/coredns/coredns/plugin/pkg/log" + +func init() { clog.Discard() } diff --git a/metrics.go b/metrics.go new file mode 100644 index 0000000..93f0080 --- /dev/null +++ b/metrics.go @@ -0,0 +1,67 @@ +package cache + +import ( + "github.com/coredns/coredns/plugin" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" +) + +var ( + // cacheSize is total elements in the cache by cache type. + cacheSize = promauto.NewGaugeVec(prometheus.GaugeOpts{ + Namespace: plugin.Namespace, + Subsystem: "host_specific_cache", + Name: "entries", + Help: "The number of elements in the cache.", + }, []string{"server", "type", "zones", "view"}) + // cacheRequests is a counter of all requests through the cache. 
+ cacheRequests = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: plugin.Namespace, + Subsystem: "host_specific_cache", + Name: "requests_total", + Help: "The count of cache requests.", + }, []string{"server", "zones", "view"}) + // cacheHits is counter of cache hits by cache type. + cacheHits = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: plugin.Namespace, + Subsystem: "host_specific_cache", + Name: "hits_total", + Help: "The count of cache hits.", + }, []string{"server", "type", "zones", "view"}) + // cacheMisses is the counter of cache misses. - Deprecated + cacheMisses = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: plugin.Namespace, + Subsystem: "host_specific_cache", + Name: "misses_total", + Help: "The count of cache misses. Deprecated, derive misses from cache hits/requests counters.", + }, []string{"server", "zones", "view"}) + // cachePrefetches is the number of time the cache has prefetched a cached item. + cachePrefetches = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: plugin.Namespace, + Subsystem: "host_specific_cache", + Name: "prefetch_total", + Help: "The number of times the cache has prefetched a cached item.", + }, []string{"server", "zones", "view"}) + // cacheDrops is the number responses that are not cached, because the reply is malformed. + cacheDrops = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: plugin.Namespace, + Subsystem: "host_specific_cache", + Name: "drops_total", + Help: "The number responses that are not cached, because the reply is malformed.", + }, []string{"server", "zones", "view"}) + // servedStale is the number of requests served from stale cache entries. 
+ servedStale = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: plugin.Namespace, + Subsystem: "host_specific_cache", + Name: "served_stale_total", + Help: "The number of requests served from stale cache entries.", + }, []string{"server", "zones", "view"}) + // evictions is the counter of cache evictions. + evictions = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: plugin.Namespace, + Subsystem: "host_specific_cache", + Name: "evictions_total", + Help: "The count of cache evictions.", + }, []string{"server", "type", "zones", "view"}) +) diff --git a/prefetch_test.go b/prefetch_test.go new file mode 100644 index 0000000..3085fe0 --- /dev/null +++ b/prefetch_test.go @@ -0,0 +1,228 @@ +package cache + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +func TestPrefetch(t *testing.T) { + tests := []struct { + qname string + ttl int + prefetch int + verifications []verification + }{ + { + qname: "hits.reset.example.org.", + ttl: 80, + prefetch: 1, + verifications: []verification{ + { + after: 0 * time.Second, + answer: "hits.reset.example.org. 80 IN A 127.0.0.1", + fetch: true, // Initial fetch + }, + { + after: 73 * time.Second, + answer: "hits.reset.example.org. 7 IN A 127.0.0.1", + fetch: true, // Triggers prefetch with 7 TTL (10% of 80 = 8 TTL threshold) + }, + { + after: 80 * time.Second, + answer: "hits.reset.example.org. 73 IN A 127.0.0.2", + }, + }, + }, + { + qname: "short.ttl.example.org.", + ttl: 5, + prefetch: 1, + verifications: []verification{ + { + after: 0 * time.Second, + answer: "short.ttl.example.org. 5 IN A 127.0.0.1", + fetch: true, + }, + { + after: 1 * time.Second, + answer: "short.ttl.example.org. 4 IN A 127.0.0.1", + }, + { + after: 4 * time.Second, + answer: "short.ttl.example.org. 
1 IN A 127.0.0.1", + fetch: true, + }, + { + after: 5 * time.Second, + answer: "short.ttl.example.org. 4 IN A 127.0.0.2", + }, + }, + }, + { + qname: "no.prefetch.example.org.", + ttl: 30, + prefetch: 0, + verifications: []verification{ + { + after: 0 * time.Second, + answer: "no.prefetch.example.org. 30 IN A 127.0.0.1", + fetch: true, + }, + { + after: 15 * time.Second, + answer: "no.prefetch.example.org. 15 IN A 127.0.0.1", + }, + { + after: 29 * time.Second, + answer: "no.prefetch.example.org. 1 IN A 127.0.0.1", + }, + { + after: 30 * time.Second, + answer: "no.prefetch.example.org. 30 IN A 127.0.0.2", + fetch: true, + }, + }, + }, + { + // tests whether cache prefetches with the do bit + qname: "do.prefetch.example.org.", + ttl: 80, + prefetch: 1, + verifications: []verification{ + { + after: 0 * time.Second, + answer: "do.prefetch.example.org. 80 IN A 127.0.0.1", + do: true, + fetch: true, + }, + { + after: 73 * time.Second, + answer: "do.prefetch.example.org. 7 IN A 127.0.0.1", + do: true, + fetch: true, + }, + { + after: 80 * time.Second, + answer: "do.prefetch.example.org. 73 IN A 127.0.0.2", + do: true, + }, + { + // Should be 127.0.0.3 as 127.0.0.2 was the prefetch WITH do bit + after: 80 * time.Second, + answer: "do.prefetch.example.org. 80 IN A 127.0.0.3", + fetch: true, + }, + }, + }, + { + // tests whether cache prefetches with the cd bit + qname: "cd.prefetch.example.org.", + ttl: 80, + prefetch: 1, + verifications: []verification{ + { + after: 0 * time.Second, + answer: "cd.prefetch.example.org. 80 IN A 127.0.0.1", + cd: true, + fetch: true, + }, + { + after: 73 * time.Second, + answer: "cd.prefetch.example.org. 7 IN A 127.0.0.1", + cd: true, + fetch: true, + }, + { + after: 80 * time.Second, + answer: "cd.prefetch.example.org. 73 IN A 127.0.0.2", + cd: true, + }, + { + // Should be 127.0.0.3 as 127.0.0.2 was the prefetch WITH cd bit + after: 80 * time.Second, + answer: "cd.prefetch.example.org. 
80 IN A 127.0.0.3", + fetch: true, + }, + }, + }, + } + + t0, err := time.Parse(time.RFC3339, "2018-01-01T14:00:00+00:00") + if err != nil { + t.Fatal(err) + } + for _, tt := range tests { + t.Run(tt.qname, func(t *testing.T) { + fetchc := make(chan struct{}, 1) + + c := New() + c.Next = prefetchHandler(tt.qname, tt.ttl, fetchc) + c.prefetch = tt.prefetch + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + + for _, v := range tt.verifications { + c.now = func() time.Time { return t0.Add(v.after) } + + req := new(dns.Msg) + req.SetQuestion(tt.qname, dns.TypeA) + req.CheckingDisabled = v.cd + req.SetEdns0(512, v.do) + + c.ServeDNS(context.TODO(), rec, req) + if v.fetch { + select { + case <-fetchc: + // Prefetch handler was called. + case <-time.After(time.Second): + t.Fatalf("After %s: want request to trigger a prefetch", v.after) + } + } + if want, got := dns.RcodeSuccess, rec.Rcode; want != got { + t.Errorf("After %s: want rcode %d, got %d", v.after, want, got) + } + if want, got := 1, len(rec.Msg.Answer); want != got { + t.Errorf("After %s: want %d answer RR, got %d", v.after, want, got) + } + if want, got := test.A(v.answer).String(), rec.Msg.Answer[0].String(); want != got { + t.Errorf("After %s: want answer %s, got %s", v.after, want, got) + } + } + }) + } +} + +type verification struct { + after time.Duration + answer string + do bool + cd bool + // fetch defines whether a request is sent to the next handler. + fetch bool +} + +// prefetchHandler is a fake plugin implementation which returns a single A +// record with the given qname and ttl. The returned IP address starts at +// 127.0.0.1 and is incremented on every request. 
+func prefetchHandler(qname string, ttl int, fetchc chan struct{}) plugin.Handler { + i := 0 + return plugin.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + i++ + m := new(dns.Msg) + m.SetQuestion(qname, dns.TypeA) + m.Response = true + m.Answer = append(m.Answer, test.A(fmt.Sprintf("%s %d IN A 127.0.0.%d", qname, ttl, i))) + + w.WriteMsg(m) + fetchc <- struct{}{} + return dns.RcodeSuccess, nil + }) +} diff --git a/setup.go b/setup.go new file mode 100644 index 0000000..363890c --- /dev/null +++ b/setup.go @@ -0,0 +1,261 @@ +package cache + +import ( + "errors" + "fmt" + "strconv" + "strings" + "time" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/pkg/cache" + clog "github.com/coredns/coredns/plugin/pkg/log" +) + +var log = clog.NewWithPlugin("host_specific_cache") + +func init() { plugin.Register("host_specific_cache", setup) } + +func setup(c *caddy.Controller) error { + ca, err := cacheParse(c) + if err != nil { + return plugin.Error("host_specific_cache", err) + } + + c.OnStartup(func() error { + ca.viewMetricLabel = dnsserver.GetConfig(c).ViewName + return nil + }) + + dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler { + ca.Next = next + return ca + }) + + return nil +} + +func cacheParse(c *caddy.Controller) (*Cache, error) { + ca := New() + + j := 0 + for c.Next() { + if j > 0 { + return nil, plugin.ErrOnce + } + j++ + + // cache [ttl] [zones..] 
+ args := c.RemainingArgs() + if len(args) > 0 { + // first args may be just a number, then it is the ttl, if not it is a zone + ttl, err := strconv.Atoi(args[0]) + if err == nil { + // Reserve 0 (and smaller for future things) + if ttl <= 0 { + return nil, fmt.Errorf("cache TTL can not be zero or negative: %d", ttl) + } + ca.pttl = time.Duration(ttl) * time.Second + ca.nttl = time.Duration(ttl) * time.Second + args = args[1:] + } + } + origins := plugin.OriginsFromArgsOrServerBlock(args, c.ServerBlockKeys) + + // Refinements? In an extra block. + for c.NextBlock() { + switch c.Val() { + // first number is cap, second is an new ttl + case Success: + args := c.RemainingArgs() + if len(args) == 0 { + return nil, c.ArgErr() + } + pcap, err := strconv.Atoi(args[0]) + if err != nil { + return nil, err + } + ca.pcap = pcap + if len(args) > 1 { + pttl, err := strconv.Atoi(args[1]) + if err != nil { + return nil, err + } + // Reserve 0 (and smaller for future things) + if pttl <= 0 { + return nil, fmt.Errorf("cache TTL can not be zero or negative: %d", pttl) + } + ca.pttl = time.Duration(pttl) * time.Second + if len(args) > 2 { + minpttl, err := strconv.Atoi(args[2]) + if err != nil { + return nil, err + } + // Reserve < 0 + if minpttl < 0 { + return nil, fmt.Errorf("cache min TTL can not be negative: %d", minpttl) + } + ca.minpttl = time.Duration(minpttl) * time.Second + } + } + case Denial: + args := c.RemainingArgs() + if len(args) == 0 { + return nil, c.ArgErr() + } + ncap, err := strconv.Atoi(args[0]) + if err != nil { + return nil, err + } + ca.ncap = ncap + if len(args) > 1 { + nttl, err := strconv.Atoi(args[1]) + if err != nil { + return nil, err + } + // Reserve 0 (and smaller for future things) + if nttl <= 0 { + return nil, fmt.Errorf("cache TTL can not be zero or negative: %d", nttl) + } + ca.nttl = time.Duration(nttl) * time.Second + if len(args) > 2 { + minnttl, err := strconv.Atoi(args[2]) + if err != nil { + return nil, err + } + // Reserve < 0 + if minnttl 
< 0 { + return nil, fmt.Errorf("cache min TTL can not be negative: %d", minnttl) + } + ca.minnttl = time.Duration(minnttl) * time.Second + } + } + case "prefetch": + args := c.RemainingArgs() + if len(args) == 0 || len(args) > 3 { + return nil, c.ArgErr() + } + amount, err := strconv.Atoi(args[0]) + if err != nil { + return nil, err + } + if amount < 0 { + return nil, fmt.Errorf("prefetch amount should be positive: %d", amount) + } + ca.prefetch = amount + + if len(args) > 1 { + dur, err := time.ParseDuration(args[1]) + if err != nil { + return nil, err + } + ca.duration = dur + } + if len(args) > 2 { + pct := args[2] + if x := pct[len(pct)-1]; x != '%' { + return nil, fmt.Errorf("last character of percentage should be `%%`, but is: %q", x) + } + pct = pct[:len(pct)-1] + + num, err := strconv.Atoi(pct) + if err != nil { + return nil, err + } + if num < 10 || num > 90 { + return nil, fmt.Errorf("percentage should fall in range [10, 90]: %d", num) + } + ca.percentage = num + } + + case "serve_stale": + args := c.RemainingArgs() + if len(args) > 2 { + return nil, c.ArgErr() + } + ca.staleUpTo = 1 * time.Hour + if len(args) > 0 { + d, err := time.ParseDuration(args[0]) + if err != nil { + return nil, err + } + if d < 0 { + return nil, errors.New("invalid negative duration for serve_stale") + } + ca.staleUpTo = d + } + ca.verifyStale = false + if len(args) > 1 { + mode := strings.ToLower(args[1]) + if mode != "immediate" && mode != "verify" { + return nil, fmt.Errorf("invalid value for serve_stale refresh mode: %s", mode) + } + ca.verifyStale = mode == "verify" + } + case "servfail": + args := c.RemainingArgs() + if len(args) != 1 { + return nil, c.ArgErr() + } + d, err := time.ParseDuration(args[0]) + if err != nil { + return nil, err + } + if d < 0 { + return nil, errors.New("invalid negative ttl for servfail") + } + if d > 5*time.Minute { + // RFC 2308 prohibits caching SERVFAIL longer than 5 minutes + return nil, errors.New("caching SERVFAIL responses over 5 minutes 
is not permitted") + } + ca.failttl = d + case "disable": + // disable [success|denial] [zones]... + args := c.RemainingArgs() + if len(args) < 1 { + return nil, c.ArgErr() + } + + var zones []string + if len(args) > 1 { + for _, z := range args[1:] { // args[1:] define the list of zones to disable + nz := plugin.Name(z).Normalize() + if nz == "" { + return nil, fmt.Errorf("invalid disabled zone: %s", z) + } + zones = append(zones, nz) + } + } else { + // if no zones specified, default to root + zones = []string{"."} + } + + switch args[0] { // args[0] defines which cache to disable + case Denial: + ca.nexcept = zones + case Success: + ca.pexcept = zones + default: + return nil, fmt.Errorf("cache type for disable must be %q or %q", Success, Denial) + } + case "keepttl": + args := c.RemainingArgs() + if len(args) != 0 { + return nil, c.ArgErr() + } + ca.keepttl = true + default: + return nil, c.ArgErr() + } + } + + ca.Zones = origins + ca.zonesMetricLabel = strings.Join(origins, ",") + ca.pcache = cache.New(ca.pcap) + ca.ncache = cache.New(ca.ncap) + } + + return ca, nil +} diff --git a/setup_test.go b/setup_test.go new file mode 100644 index 0000000..46ac5bd --- /dev/null +++ b/setup_test.go @@ -0,0 +1,262 @@ +package cache + +import ( + "fmt" + "testing" + "time" + + "github.com/coredns/caddy" +) + +func TestSetup(t *testing.T) { + tests := []struct { + input string + shouldErr bool + expectedNcap int + expectedPcap int + expectedNttl time.Duration + expectedMinNttl time.Duration + expectedPttl time.Duration + expectedMinPttl time.Duration + expectedPrefetch int + }{ + {`cache`, false, defaultCap, defaultCap, maxNTTL, minNTTL, maxTTL, minTTL, 0}, + {`cache {}`, false, defaultCap, defaultCap, maxNTTL, minNTTL, maxTTL, minTTL, 0}, + {`cache example.nl { + success 10 + }`, false, defaultCap, 10, maxNTTL, minNTTL, maxTTL, minTTL, 0}, + {`cache example.nl { + success 10 1800 30 + }`, false, defaultCap, 10, maxNTTL, minNTTL, 1800 * time.Second, 30 * time.Second, 0}, + 
{`cache example.nl { + success 10 + denial 10 15 + }`, false, 10, 10, 15 * time.Second, minNTTL, maxTTL, minTTL, 0}, + {`cache example.nl { + success 10 + denial 10 15 2 + }`, false, 10, 10, 15 * time.Second, 2 * time.Second, maxTTL, minTTL, 0}, + {`cache 25 example.nl { + success 10 + denial 10 15 + }`, false, 10, 10, 15 * time.Second, minNTTL, 25 * time.Second, minTTL, 0}, + {`cache 25 example.nl { + success 10 + denial 10 15 5 + }`, false, 10, 10, 15 * time.Second, 5 * time.Second, 25 * time.Second, minTTL, 0}, + {`cache aaa example.nl`, false, defaultCap, defaultCap, maxNTTL, minNTTL, maxTTL, minTTL, 0}, + {`cache { + prefetch 10 + }`, false, defaultCap, defaultCap, maxNTTL, minNTTL, maxTTL, minTTL, 10}, + + // fails + {`cache example.nl { + success + denial 10 15 + }`, true, defaultCap, defaultCap, maxTTL, minNTTL, maxTTL, minTTL, 0}, + {`cache example.nl { + success 15 + denial aaa + }`, true, defaultCap, defaultCap, maxTTL, minNTTL, maxTTL, minTTL, 0}, + {`cache example.nl { + positive 15 + negative aaa + }`, true, defaultCap, defaultCap, maxTTL, minNTTL, maxTTL, minTTL, 0}, + {`cache 0 example.nl`, true, defaultCap, defaultCap, maxTTL, minNTTL, maxTTL, minTTL, 0}, + {`cache -1 example.nl`, true, defaultCap, defaultCap, maxTTL, minNTTL, maxTTL, minTTL, 0}, + {`cache 1 example.nl { + positive 0 + }`, true, defaultCap, defaultCap, maxTTL, minNTTL, maxTTL, minTTL, 0}, + {`cache 1 example.nl { + positive 0 + prefetch -1 + }`, true, defaultCap, defaultCap, maxTTL, minNTTL, maxTTL, minTTL, 0}, + {`cache 1 example.nl { + prefetch 0 blurp + }`, true, defaultCap, defaultCap, maxTTL, minNTTL, maxTTL, minTTL, 0}, + {`cache + cache`, true, defaultCap, defaultCap, maxTTL, minNTTL, maxTTL, minTTL, 0}, + } + for i, test := range tests { + c := caddy.NewTestController("dns", test.input) + ca, err := cacheParse(c) + if test.shouldErr && err == nil { + t.Errorf("Test %v: Expected error but found nil", i) + continue + } else if !test.shouldErr && err != nil { + t.Errorf("Test 
%v: Expected no error but found error: %v", i, err) + continue + } + if test.shouldErr && err != nil { + continue + } + + if ca.ncap != test.expectedNcap { + t.Errorf("Test %v: Expected ncap %v but found: %v", i, test.expectedNcap, ca.ncap) + } + if ca.pcap != test.expectedPcap { + t.Errorf("Test %v: Expected pcap %v but found: %v", i, test.expectedPcap, ca.pcap) + } + if ca.nttl != test.expectedNttl { + t.Errorf("Test %v: Expected nttl %v but found: %v", i, test.expectedNttl, ca.nttl) + } + if ca.minnttl != test.expectedMinNttl { + t.Errorf("Test %v: Expected minnttl %v but found: %v", i, test.expectedMinNttl, ca.minnttl) + } + if ca.pttl != test.expectedPttl { + t.Errorf("Test %v: Expected pttl %v but found: %v", i, test.expectedPttl, ca.pttl) + } + if ca.minpttl != test.expectedMinPttl { + t.Errorf("Test %v: Expected minpttl %v but found: %v", i, test.expectedMinPttl, ca.minpttl) + } + if ca.prefetch != test.expectedPrefetch { + t.Errorf("Test %v: Expected prefetch %v but found: %v", i, test.expectedPrefetch, ca.prefetch) + } + } +} + +func TestServeStale(t *testing.T) { + tests := []struct { + input string + shouldErr bool + staleUpTo time.Duration + verifyStale bool + }{ + {"serve_stale", false, 1 * time.Hour, false}, + {"serve_stale 20m", false, 20 * time.Minute, false}, + {"serve_stale 1h20m", false, 80 * time.Minute, false}, + {"serve_stale 0m", false, 0, false}, + {"serve_stale 0", false, 0, false}, + {"serve_stale 0 verify", false, 0, true}, + {"serve_stale 0 immediate", false, 0, false}, + {"serve_stale 0 VERIFY", false, 0, true}, + // fails + {"serve_stale 20", true, 0, false}, + {"serve_stale -20m", true, 0, false}, + {"serve_stale aa", true, 0, false}, + {"serve_stale 1m nono", true, 0, false}, + {"serve_stale 0 after nono", true, 0, false}, + } + for i, test := range tests { + c := caddy.NewTestController("dns", fmt.Sprintf("cache {\n%s\n}", test.input)) + ca, err := cacheParse(c) + if test.shouldErr && err == nil { + t.Errorf("Test %v: Expected 
error but found nil", i) + continue + } else if !test.shouldErr && err != nil { + t.Errorf("Test %v: Expected no error but found error: %v", i, err) + continue + } + if test.shouldErr && err != nil { + continue + } + if ca.staleUpTo != test.staleUpTo { + t.Errorf("Test %v: Expected stale %v but found: %v", i, test.staleUpTo, ca.staleUpTo) + } + } +} + +func TestServfail(t *testing.T) { + tests := []struct { + input string + shouldErr bool + failttl time.Duration + }{ + {"servfail 1s", false, 1 * time.Second}, + {"servfail 5m", false, 5 * time.Minute}, + {"servfail 0s", false, 0}, + {"servfail 0", false, 0}, + // fails + {"servfail", true, minNTTL}, + {"servfail 6m", true, minNTTL}, + {"servfail 20", true, minNTTL}, + {"servfail -1s", true, minNTTL}, + {"servfail aa", true, minNTTL}, + {"servfail 1m invalid", true, minNTTL}, + } + for i, test := range tests { + c := caddy.NewTestController("dns", fmt.Sprintf("cache {\n%s\n}", test.input)) + ca, err := cacheParse(c) + if test.shouldErr && err == nil { + t.Errorf("Test %v: Expected error but found nil", i) + continue + } else if !test.shouldErr && err != nil { + t.Errorf("Test %v: Expected no error but found error: %v", i, err) + continue + } + if test.shouldErr && err != nil { + continue + } + if ca.failttl != test.failttl { + t.Errorf("Test %v: Expected stale %v but found: %v", i, test.failttl, ca.staleUpTo) + } + } +} + +func TestDisable(t *testing.T) { + tests := []struct { + input string + shouldErr bool + nexcept []string + pexcept []string + }{ + // positive + {"disable denial example.com example.org", false, []string{"example.com.", "example.org."}, nil}, + {"disable success example.com example.org", false, nil, []string{"example.com.", "example.org."}}, + {"disable denial", false, []string{"."}, nil}, + {"disable success", false, nil, []string{"."}}, + {"disable denial example.com example.org\ndisable success example.com example.org", false, + []string{"example.com.", "example.org."}, []string{"example.com.", 
"example.org."}},
+		// negative
+		{"disable invalid example.com example.org", true, nil, nil},
+	}
+	for i, test := range tests {
+		c := caddy.NewTestController("dns", fmt.Sprintf("cache {\n%s\n}", test.input))
+		ca, err := cacheParse(c)
+		if test.shouldErr && err == nil {
+			t.Errorf("Test %v: Expected error but found nil", i)
+			continue
+		} else if !test.shouldErr && err != nil {
+			t.Errorf("Test %v: Expected no error but found error: %v", i, err)
+			continue
+		}
+		if test.shouldErr {
+			continue
+		}
+		if fmt.Sprintf("%v", test.nexcept) != fmt.Sprintf("%v", ca.nexcept) {
+			t.Errorf("Test %v: Expected %v but got: %v", i, test.nexcept, ca.nexcept)
+		}
+		if fmt.Sprintf("%v", test.pexcept) != fmt.Sprintf("%v", ca.pexcept) {
+			t.Errorf("Test %v: Expected %v but got: %v", i, test.pexcept, ca.pexcept)
+		}
+	}
+}
+
+// TestKeepttl verifies that the bare "keepttl" option parses and enables
+// ca.keepttl, and that supplying any argument is a parse error.
+func TestKeepttl(t *testing.T) {
+	tests := []struct {
+		input     string
+		shouldErr bool
+	}{
+		// positive
+		{"keepttl", false},
+		// negative
+		{"keepttl arg1", true},
+	}
+	for i, test := range tests {
+		c := caddy.NewTestController("dns", fmt.Sprintf("cache {\n%s\n}", test.input))
+		ca, err := cacheParse(c)
+		if test.shouldErr && err == nil {
+			t.Errorf("Test %v: Expected error but found nil", i)
+			continue
+		} else if !test.shouldErr && err != nil {
+			t.Errorf("Test %v: Expected no error but found error: %v", i, err)
+			continue
+		}
+		if test.shouldErr {
+			continue
+		}
+		if !ca.keepttl {
+			t.Errorf("Test %v: Expected keepttl enabled but disabled", i)
+		}
+	}
+}
diff --git a/spoof_test.go b/spoof_test.go
new file mode 100644
index 0000000..20d7e8d
--- /dev/null
+++ b/spoof_test.go
@@ -0,0 +1,82 @@
+package cache
+
+import (
+	"context"
+	"testing"
+
+	"github.com/coredns/coredns/plugin"
+	"github.com/coredns/coredns/plugin/pkg/dnstest"
+	"github.com/coredns/coredns/plugin/test"
+
+	"github.com/miekg/dns"
+)
+
+func TestSpoof(t *testing.T) {
+	// Send query for example.org, get reply for example.net; should not be cached.
+	c := New()
+	c.Next = spoofHandler(true)
+
+	req := new(dns.Msg)
+	req.SetQuestion("example.org.", dns.TypeA)
+	rec := dnstest.NewRecorder(&test.ResponseWriter{})
+
+	c.ServeDNS(context.TODO(), rec, req)
+
+	qname := rec.Msg.Question[0].Name
+	if c.pcache.Len() != 0 {
+		t.Errorf("Cached %s, while reply had %s", "example.org.", qname)
+	}
+
+	// qtype: query for MX, but the reply's question carries qtype A; the
+	// mismatching reply must likewise not be cached.
+	c.Next = spoofHandlerType()
+	req.SetQuestion("example.org.", dns.TypeMX)
+
+	c.ServeDNS(context.TODO(), rec, req)
+
+	qtype := rec.Msg.Question[0].Qtype
+	if c.pcache.Len() != 0 {
+		t.Errorf("Cached %s type %d, while reply had %d", "example.org.", dns.TypeMX, qtype)
+	}
+}
+
+func TestResponse(t *testing.T) {
+	// The reply has the Response (QR) bit cleared (spoofHandler(false)), i.e. it
+	// is a query rather than a response; such a message should not be cached.
+	c := New()
+	c.Next = spoofHandler(false)
+
+	req := new(dns.Msg)
+	req.SetQuestion("example.net.", dns.TypeA)
+	rec := dnstest.NewRecorder(&test.ResponseWriter{})
+
+	c.ServeDNS(context.TODO(), rec, req)
+
+	if c.pcache.Len() != 0 {
+		t.Errorf("Cached %s, while reply had response set to %t", "example.net.", rec.Msg.Response)
+	}
+}
+
+// spoofHandler is a fake plugin implementation which returns a single A record for example.org. The qname in the
+// question section is set to example.NET (i.e. they *don't* match). The Response (QR) bit is set from the argument.
+func spoofHandler(response bool) plugin.Handler {
+	return plugin.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
+		m := new(dns.Msg)
+		m.SetQuestion("example.net.", dns.TypeA)
+		m.Response = response
+		m.Answer = []dns.RR{test.A("example.org. IN A 127.0.0.53")}
+		w.WriteMsg(m)
+		return dns.RcodeSuccess, nil
+	})
+}
+
+// spoofHandlerType is a fake plugin implementation which returns a single MX record for example.org. The qtype in the
+// question section is set to A.
+func spoofHandlerType() plugin.Handler {
+	return plugin.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
+		m := new(dns.Msg)
+		// Deliberate mismatch: the question advertises qtype A while the
+		// answer section carries an MX record.
+		m.SetQuestion("example.org.", dns.TypeA)
+		m.Response = true
+		m.Answer = []dns.RR{test.MX("example.org. IN MX 10 mail.example.org.")}
+		w.WriteMsg(m)
+		return dns.RcodeSuccess, nil
+	})
+}