Skip to content

Commit 6cdaf1a

Browse files
committed
feat: server support write cache to file
1 parent c1abbaf commit 6cdaf1a

File tree

8 files changed

+173
-59
lines changed

8 files changed

+173
-59
lines changed

exec/server/main.go

+5-1
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,10 @@ func init() {
9999
if err := server.InitACMEAccount(); err != nil {
100100
log.Fatalf("[ERR] Failed init ACME account: %s", err)
101101
}
102+
103+
if err := server.InitCache(); err != nil {
104+
log.Fatalf("[ERR] Failed init server cache: %s", err)
105+
}
102106
}
103107

104108
func serveHttps() {
@@ -109,7 +113,7 @@ func serveHttps() {
109113

110114
for {
111115
cert := entry.Cert()
112-
certificate, err := tls.X509KeyPair(cert.Cert, cert.Key)
116+
certificate, err := tls.X509KeyPair(cert.FullChain, cert.Key)
113117
if err != nil {
114118
log.Fatalf("[ERR] Failed to load cert: %s", err)
115119
}

pkg/client/client.go

+8-8
Original file line numberDiff line numberDiff line change
@@ -82,19 +82,19 @@ func (r *CertDXClientDaemon) requestCert(domains []string) *types.HttpCertResp {
8282
return nil
8383
}
8484

85-
func (r *CertDXClientDaemon) certWatchDog(cert config.ClientCertification, onChanged ...func(cert, key []byte, c *config.ClientCertification)) {
86-
var currentCert, currentKey []byte
85+
func (r *CertDXClientDaemon) certWatchDog(cert config.ClientCertification, onChanged ...func(fullchain, key []byte, c *config.ClientCertification)) {
86+
var currentFullChain, currentKey []byte
8787
sleepTime := 1 * time.Hour // default sleep time
8888
for {
8989
log.Printf("[INF] Request cert %v", cert.Domains)
9090
resp := r.requestCert(cert.Domains)
9191
if resp != nil {
9292
sleepTime = resp.RenewTimeLeft / 4
93-
if !bytes.Equal(currentCert, resp.Cert) || !bytes.Equal(currentKey, resp.Key) {
93+
if !bytes.Equal(currentFullChain, resp.FullChain) || !bytes.Equal(currentKey, resp.Key) {
9494
log.Printf("[INF] Notify cert %v changed", cert.Domains)
95-
currentCert, currentKey = resp.Cert, resp.Key
95+
currentFullChain, currentKey = resp.FullChain, resp.Key
9696
for _, handleFunc := range onChanged {
97-
handleFunc(resp.Cert, resp.Key, &cert)
97+
handleFunc(resp.FullChain, resp.Key, &cert)
9898
}
9999
} else {
100100
log.Printf("[INF] Cert %v not changed", cert.Domains)
@@ -105,10 +105,10 @@ func (r *CertDXClientDaemon) certWatchDog(cert config.ClientCertification, onCha
105105
}
106106
}
107107

108-
func writeCertAndDoCommand(cert, key []byte, c *config.ClientCertification) {
108+
func writeCertAndDoCommand(fullchain, key []byte, c *config.ClientCertification) {
109109
var doCommand, ce, ke bool
110110

111-
certPath, keyPath := c.GetCertAndKeyPath()
111+
certPath, keyPath := c.GetFullChainAndKeyPath()
112112
ce, err := checkFileAndCreate(certPath)
113113
if err != nil {
114114
goto ERR
@@ -120,7 +120,7 @@ func writeCertAndDoCommand(cert, key []byte, c *config.ClientCertification) {
120120
// if cert file is firstly created, don't do reload command
121121
doCommand = ce && ke
122122

123-
err = os.WriteFile(certPath, cert, 0o777)
123+
err = os.WriteFile(certPath, fullchain, 0o777)
124124
if err != nil {
125125
goto ERR
126126
}

pkg/config/config.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,8 @@ type ClientCertification struct {
9999
ReloadCommand string `toml:"reloadCommand" json:"reload_command,omitempty"`
100100
}
101101

102-
func (c *ClientCertification) GetCertAndKeyPath() (cert, key string) {
103-
cert = path.Join(c.SavePath, fmt.Sprintf("%s.pem", c.Name))
102+
func (c *ClientCertification) GetFullChainAndKeyPath() (fullchain, key string) {
103+
fullchain = path.Join(c.SavePath, fmt.Sprintf("%s.pem", c.Name))
104104
key = path.Join(c.SavePath, fmt.Sprintf("%s.key", c.Name))
105105
return
106106
}

pkg/server/acme.go

-19
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ import (
1212
"encoding/pem"
1313
"fmt"
1414
"os"
15-
"path"
1615
"strings"
1716
"time"
1817

@@ -93,24 +92,6 @@ func (u *MyUser) GetPrivateKey() crypto.PrivateKey {
9392
return u.Key
9493
}
9594

96-
func getPrivateKeySavePath(email string, ACMEProvider string) (string, error) {
97-
saveDir, err := os.Getwd()
98-
if err != nil {
99-
return "", err
100-
}
101-
saveDir = path.Join(saveDir, "private")
102-
keyName := fmt.Sprintf("%s_%s.key", email, ACMEProvider)
103-
104-
if _, err := os.Stat(saveDir); os.IsNotExist(err) {
105-
err := os.Mkdir(saveDir, 0o600)
106-
if err != nil {
107-
return "", fmt.Errorf("cannot create path: %s to save account key", saveDir)
108-
}
109-
}
110-
111-
return path.Join(saveDir, keyName), nil
112-
}
113-
11495
func RegisterAccount(ACMEProvider, Email, Kid, Hmac string) error {
11596
keyPath, err := getPrivateKeySavePath(Email, ACMEProvider)
11697
if err != nil {

pkg/server/http.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ func handleCertReq(w *http.ResponseWriter, r *http.Request) {
6969
cert = cachedCert.Cert()
7070
resp, err = json.Marshal(&types.HttpCertResp{
7171
RenewTimeLeft: Config.ACME.RenewTimeLeftDuration,
72-
Cert: cert.Cert,
72+
FullChain: cert.FullChain,
7373
Key: cert.Key,
7474
})
7575
if err != nil {
@@ -78,7 +78,7 @@ func handleCertReq(w *http.ResponseWriter, r *http.Request) {
7878

7979
(*w).Header().Set("Content-Type", "application/json")
8080
(*w).Write(resp)
81-
log.Printf("[INF] Http sent cert: %v to: %s", cachedCert.Domains, r.RemoteAddr)
81+
log.Printf("[INF] Http sent cert: %v to: %s", cachedCert.domains, r.RemoteAddr)
8282
return
8383

8484
ERR:

pkg/server/path.go

+39
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
package server
2+
3+
import (
4+
"fmt"
5+
"os"
6+
"path"
7+
)
8+
9+
func getPrivateKeySavePath(email string, ACMEProvider string) (string, error) {
10+
saveDir, err := os.Getwd()
11+
if err != nil {
12+
return "", err
13+
}
14+
saveDir = path.Join(saveDir, "private")
15+
keyName := fmt.Sprintf("%s_%s.key", email, ACMEProvider)
16+
17+
if _, err := os.Stat(saveDir); os.IsNotExist(err) {
18+
err := os.Mkdir(saveDir, 0o600)
19+
if err != nil {
20+
return "", fmt.Errorf("cannot create path: %s to save account key", saveDir)
21+
}
22+
}
23+
24+
return path.Join(saveDir, keyName), nil
25+
}
26+
27+
func getCacheSavePath() (cachePath string, exist bool) {
28+
saveDir, err := os.Getwd()
29+
if err != nil {
30+
return "", false
31+
}
32+
33+
cacheFile := path.Join(saveDir, "cache.json")
34+
if _, err := os.Stat(cacheFile); os.IsNotExist(err) {
35+
return cacheFile, false
36+
}
37+
38+
return cacheFile, true
39+
}

pkg/server/server.go

+116-26
Original file line numberDiff line numberDiff line change
@@ -1,51 +1,139 @@
11
package server
22

33
import (
4+
"encoding/json"
5+
"fmt"
46
"log"
5-
"pkg.para.party/certdx/pkg/config"
6-
"pkg.para.party/certdx/pkg/utils"
7+
"os"
78
"slices"
89
"strings"
910
"sync"
1011
"sync/atomic"
1112
"time"
13+
14+
"pkg.para.party/certdx/pkg/config"
15+
"pkg.para.party/certdx/pkg/utils"
1216
)
1317

1418
type CertT struct {
15-
Cert, Key []byte
16-
ValidBefore time.Time
19+
FullChain []byte `json:"fullChain"`
20+
Key []byte `json:"key"`
21+
ValidBefore time.Time `json:"validBefore"`
1722
}
1823

19-
type ServerCertCacheEntry struct {
20-
Domains []string
24+
type ServerCacheFileEntry struct {
25+
Domains []string `json:"domains"`
26+
Cert CertT `json:"cert"`
27+
}
2128

22-
cert CertT
23-
Mutex sync.Mutex
29+
type ServerCertCacheEntry struct {
30+
domains []string
31+
cert CertT
32+
mutex sync.Mutex
2433

2534
Listening atomic.Bool
2635
Updated atomic.Pointer[chan struct{}]
2736
Stop atomic.Pointer[chan struct{}]
2837
}
2938

3039
type ServerCertCacheT struct {
31-
entrys []*ServerCertCacheEntry
32-
mutex sync.Mutex
40+
entrys []*ServerCertCacheEntry
41+
mutex sync.Mutex
42+
updated atomic.Pointer[chan struct{}]
3343
}
3444

3545
var ServerCertCache = &ServerCertCacheT{}
3646
var Config = &config.ServerConfigT{}
3747

48+
func InitCache() error {
49+
updated := make(chan struct{})
50+
ServerCertCache.updated.Store(&updated)
51+
52+
if err := loadCacheFile(); err != nil {
53+
return err
54+
}
55+
56+
go func() {
57+
for {
58+
<-*ServerCertCache.updated.Load()
59+
log.Printf("[INF] Write domains cache to file")
60+
61+
if err := writeCacheFile(); err != nil {
62+
log.Printf("[WRN] Write domains cache to file failed: %s", err)
63+
}
64+
}
65+
}()
66+
67+
return nil
68+
}
69+
70+
func loadCacheFile() error {
71+
cachePath, exist := getCacheSavePath()
72+
if !exist {
73+
return nil
74+
}
75+
76+
cfile, err := os.ReadFile(cachePath)
77+
if err != nil {
78+
return fmt.Errorf("open cache file failed: %w", err)
79+
}
80+
81+
var caches []ServerCacheFileEntry
82+
err = json.Unmarshal(cfile, &caches)
83+
if err != nil {
84+
return fmt.Errorf("unmarshal cache file failed: %w", err)
85+
}
86+
87+
for _, cache := range caches {
88+
entry := ServerCertCache.GetEntry(cache.Domains)
89+
entry.mutex.Lock()
90+
entry.cert = cache.Cert
91+
entry.mutex.Unlock()
92+
}
93+
94+
return nil
95+
}
96+
97+
func writeCacheFile() error {
98+
ServerCertCache.mutex.Lock()
99+
defer ServerCertCache.mutex.Unlock()
100+
101+
var caches []ServerCacheFileEntry
102+
for _, entry := range ServerCertCache.entrys {
103+
entry.mutex.Lock()
104+
cache := ServerCacheFileEntry{
105+
Domains: entry.domains,
106+
Cert: entry.cert,
107+
}
108+
entry.mutex.Unlock()
109+
caches = append(caches, cache)
110+
}
111+
112+
jsonBytes, err := json.Marshal(caches)
113+
if err != nil {
114+
return fmt.Errorf("failed marshal cache file: %w", err)
115+
}
116+
117+
cachePath, _ := getCacheSavePath()
118+
err = os.WriteFile(cachePath, jsonBytes, 0o600)
119+
if err != nil {
120+
return fmt.Errorf("failed write cache file: %w", err)
121+
}
122+
123+
return nil
124+
}
125+
38126
func (s *ServerCertCacheT) GetEntry(domains []string) *ServerCertCacheEntry {
39127
s.mutex.Lock()
40128
defer s.mutex.Unlock()
41129
for _, entry := range s.entrys {
42-
if utils.SameCert(domains, entry.Domains) {
130+
if utils.SameCert(domains, entry.domains) {
43131
return entry
44132
}
45133
}
46134

47135
entry := &ServerCertCacheEntry{
48-
Domains: domains,
136+
domains: domains,
49137
}
50138
entry.Listening.Store(false)
51139
updated := make(chan struct{})
@@ -57,16 +145,16 @@ func (s *ServerCertCacheT) GetEntry(domains []string) *ServerCertCacheEntry {
57145
}
58146

59147
func (c *ServerCertCacheEntry) Cert() CertT {
60-
c.Mutex.Lock()
61-
defer c.Mutex.Unlock()
148+
c.mutex.Lock()
149+
defer c.mutex.Unlock()
62150
return c.cert
63151
}
64152

65153
func (c *ServerCertCacheEntry) Renew(retry bool) (bool, error) {
66-
c.Mutex.Lock()
67-
defer c.Mutex.Unlock()
154+
c.mutex.Lock()
155+
defer c.mutex.Unlock()
68156

69-
log.Printf("[INF] Renew cert: %v", c.Domains)
157+
log.Printf("[INF] Renew cert: %v", c.domains)
70158
if !time.Now().Before(c.cert.ValidBefore) {
71159
newValidBefore := time.Now().Truncate(1 * time.Hour).Add(Config.ACME.CertLifeTimeDuration)
72160

@@ -75,27 +163,29 @@ func (c *ServerCertCacheEntry) Renew(retry bool) (bool, error) {
75163
return false, err
76164
}
77165

78-
var cert, key []byte
166+
var fullchain, key []byte
79167
if retry {
80-
cert, key, err = acme.RetryObtain(c.Domains, newValidBefore.Add(Config.ACME.RenewTimeLeftDuration))
168+
fullchain, key, err = acme.RetryObtain(c.domains, newValidBefore.Add(Config.ACME.RenewTimeLeftDuration))
81169
} else {
82-
cert, key, err = acme.Obtain(c.Domains, newValidBefore.Add(Config.ACME.RenewTimeLeftDuration))
170+
fullchain, key, err = acme.Obtain(c.domains, newValidBefore.Add(Config.ACME.RenewTimeLeftDuration))
83171
}
84172
if err != nil {
85173
return false, err
86174
}
87175

88176
c.cert = CertT{
89177
ValidBefore: newValidBefore,
90-
Cert: cert,
178+
FullChain: fullchain,
91179
Key: key,
92180
}
93181

94-
log.Printf("[INF] Obtained cert: %v", c.Domains)
182+
log.Printf("[INF] Obtained cert: %v", c.domains)
183+
newUpdated := make(chan struct{})
184+
close(*ServerCertCache.updated.Swap(&newUpdated))
95185
return true, nil
96186
}
97187

98-
log.Printf("[INF] Cert: %v not expired", c.Domains)
188+
log.Printf("[INF] Cert: %v not expired", c.domains)
99189
return false, nil
100190
}
101191

@@ -106,13 +196,13 @@ func (c *ServerCertCacheEntry) CertWatchDog() {
106196

107197
c.Listening.Store(true)
108198
for {
109-
log.Printf("[INF] Server renew: %v", c.Domains)
199+
log.Printf("[INF] Server renew: %v", c.domains)
110200
changed, err := c.Renew(true)
111201
if err != nil {
112-
log.Printf("[ERR] Failed renew cert %s: %s", c.Domains, err)
202+
log.Printf("[ERR] Failed renew cert %s: %s", c.domains, err)
113203
} else if changed {
114204
newUpdated := make(chan struct{})
115-
log.Printf("[INF] Notify cert %s updated", c.Domains)
205+
log.Printf("[INF] Notify cert %s updated", c.domains)
116206
close(*c.Updated.Swap(&newUpdated))
117207
}
118208

0 commit comments

Comments
 (0)