From 4f50ca2a521ea5090a5d6f251501e1abb472aa49 Mon Sep 17 00:00:00 2001 From: Ajabep Date: Mon, 26 Aug 2024 13:57:44 +0200 Subject: [PATCH 1/6] Typos Yeah, doc is made with //, not ///, we're not doing rust ^^ --- internal/httpproxy/proxy.go | 2 +- internal/tcpproxy/proxy.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/httpproxy/proxy.go b/internal/httpproxy/proxy.go index 3abe90d..840f920 100644 --- a/internal/httpproxy/proxy.go +++ b/internal/httpproxy/proxy.go @@ -40,7 +40,7 @@ func makeHandleHTTP(dest string, tlsConfig *tls.Config, reuseSockets bool) func( case "80": u.Scheme = "http" case "443": - u.Scheme = "http" + u.Scheme = "https" case "": switch u.Scheme { default: diff --git a/internal/tcpproxy/proxy.go b/internal/tcpproxy/proxy.go index 5b16b73..223116f 100644 --- a/internal/tcpproxy/proxy.go +++ b/internal/tcpproxy/proxy.go @@ -36,8 +36,8 @@ func newProxy(from, to string, tlsConfig *tls.Config) *proxy { } } +// Start the proxy. Is blocking! func (p *proxy) start(ctx context.Context) error { - /// Start the proxy. Is blocking! listener, err := net.Listen("tcp", p.from) if err != nil { From abaabcbd0a493493a814e6fbde81fe79d7e37169 Mon Sep 17 00:00:00 2001 From: Ajabep Date: Mon, 26 Aug 2024 13:55:42 +0200 Subject: [PATCH 2/6] Add unit tests and fixes --- .github/workflows/go.yml | 3 + .gitignore | 1 + internal/configuration/configuration.go | 16 + .../configurationtest/configuration.go | 133 ++++ .../configurationtest/configuration_test.go | 39 ++ tests/CertificateGenerator.go | 76 +++ tests/MainSupervisor.go | 100 +++ tests/SyncedUint.go | 28 + tests/TlsServerCounter.go | 238 ++++++++ unmtlsproxy_test.go | 570 ++++++++++++++++++ 10 files changed, 1204 insertions(+) create mode 100644 internal/configuration/configurationtest/configuration.go create mode 100644 internal/configuration/configurationtest/configuration_test.go create mode 100644 tests/CertificateGenerator.go create mode 100644 tests/MainSupervisor.go create mode 100644 tests/SyncedUint.go create mode 100644 tests/TlsServerCounter.go create mode 100644 unmtlsproxy_test.go diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 46b39f4..ee83ef0 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -29,6 +29,9 @@ jobs: with: go-version: '1.22' + - name: Vet + run: go vet ./... + - name: Build run: go build -v ./... diff --git a/.gitignore b/.gitignore index 90456a1..9228694 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ unmtlsproxy_win_386.exe unmtlsproxy unmtlsproxy.exe local_build.ps1 +__debug_bin* diff --git a/internal/configuration/configuration.go b/internal/configuration/configuration.go index 1856619..5d51b84 100644 --- a/internal/configuration/configuration.go +++ b/internal/configuration/configuration.go @@ -16,7 +16,9 @@ import ( "crypto/tls" "crypto/x509" "fmt" + "net/url" "os" + "strconv" "go.aporeto.io/addedeffect/lombric" "go.aporeto.io/tg/tglib" @@ -55,6 +57,20 @@ func NewConfiguration() *Configuration { c := &Configuration{} lombric.Initialize(c) + listenUrl, err := url.Parse("http://" + c.ListenAddress) + if err != nil { + panic(err) + } + if port := listenUrl.Port(); port == "" { + panic("Invalid Listen format. Use `hostname:port`.") + } else if portInt, err := strconv.Atoi(port); err != nil { + panic(err) + } else if portInt <= 0 { + panic("Invalid Listening Port. Too low.") + } else if portInt > 65535 { + panic("Invalid Listening Port. Too High. We use TCP.") + } + if c.Mode == "tcp" { if c.DisableSocketReusing { panic("Option 'disable-socket-reusing' is forbidden in TCP mode. Socket reusing cannot being enabled, option is useless") diff --git a/internal/configuration/configurationtest/configuration.go b/internal/configuration/configurationtest/configuration.go new file mode 100644 index 0000000..ea440f5 --- /dev/null +++ b/internal/configuration/configurationtest/configuration.go @@ -0,0 +1,133 @@ +package configurationtest + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "math/big" + "net" + "net/netip" + "os" + "path/filepath" + "strings" + "time" + + "github.com/ajabep/unmtlsproxy/internal/configuration" +) + +func GenerateCertificate(certificatePath string, keyPath string) error { + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return err + } + + pvBytes, err := x509.MarshalECPrivateKey(key) + if err != nil { + return err + } + + // Generate a pem block with the private key + keyPem := pem.EncodeToMemory(&pem.Block{ + Type: "EC PRIVATE KEY", + Bytes: pvBytes, + }) + + tml := x509.Certificate{ + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(5, 0, 0), + SerialNumber: big.NewInt(123123), + Subject: pkix.Name{ + CommonName: "Hugging Department", + Organization: []string{"BlaHaj Corp."}, + }, + PublicKeyAlgorithm: x509.ECDSA, + SignatureAlgorithm: x509.ECDSAWithSHA256, + } + + cert, err := x509.CreateCertificate(rand.Reader, &tml, &tml, &key.PublicKey, key) + if err != nil { + return err + } + + // Generate a pem block with the certificate + certPem := pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: cert, + }) + + keyOut, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) + if err != nil { + return err + } + certificateOut, err := os.OpenFile(certificatePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) + if err != nil { + return err + } + + if _, err := keyOut.Write(keyPem); err != nil { + return err + } + if _, err := certificateOut.Write(certPem); err != nil { + return err + } + return nil +} + +func SetupConfigurationEnv(args map[string]string) { + for k, v := range args { + k = strings.TrimPrefix(k, "--") + k = strings.ToUpper(k) + k = "UNMTLSPROXY_" + k + k = strings.ReplaceAll(k, " ", "_") + k = strings.ReplaceAll(k, "-", "_") + + os.Setenv(k, v) + } +} + +func LoadNewConfiguration(args map[string]string) *configuration.Configuration { + SetupConfigurationEnv(args) + return configuration.NewConfiguration() +} + +func GetExampleDir(level int) (string, error) { + currentDir, err := os.Getwd() // os.Executable() + if err != nil { + return "", err + } + currentDir, err = filepath.Abs(currentDir) + if err != nil { + return "", err + } + for range level { + currentDir = filepath.Dir(currentDir) + } + return filepath.Join(currentDir, "example"), nil +} + +func NewListener() (string, *netip.Addr, uint16, error) { + var minPort, maxPort uint16 = 5000, 65535 + addr, err := netip.ParseAddr("127.0.0.1") + if err != nil { + return "", nil, 0, err + } + for { + x, err := rand.Int(rand.Reader, big.NewInt(int64(maxPort-minPort))) + if err != nil { + return "", nil, 0, err + } + port := uint16(x.Uint64()) + port += minPort + + l, err := net.Listen("tcp", fmt.Sprintf("%s:%d", addr.String(), port)) + if err != nil { + continue + } + l.Close() + return fmt.Sprintf("%s:%d", addr.String(), port), &addr, uint16(port), nil + } +} diff --git a/internal/configuration/configurationtest/configuration_test.go b/internal/configuration/configurationtest/configuration_test.go new file mode 100644 index 0000000..b273a54 --- /dev/null +++ b/internal/configuration/configurationtest/configuration_test.go @@ -0,0 +1,39 @@ +package configurationtest + +import ( + "path/filepath" + "testing" +) + +func TestNewConfigurationValidMinimalist(t *testing.T) { + exampleDir, err := GetExampleDir(3) + if err != nil { + panic(err) + } + + config := map[string]string{ + "backend": "https://client.badssl.com", + "cert": filepath.Join(exampleDir, "badssl.com-client.crt.pem"), + "cert-key": filepath.Join(exampleDir, "badssl.com-client_NOENCRYPTION.key.pem"), + "mode": "http", + } + + _ = LoadNewConfiguration(config) +} + +// func TestNewConfigurationValidClientCertificatePassword(t *testing.T) { +// exampleDir, err := GetExampleDir(3) +// if err != nil { +// panic(err) +// } +// +// config := map[string]string{ +// "backend": "https://client.badssl.com", +// "cert": filepath.Join(exampleDir, "badssl.com-client.crt.pem"), +// "cert-key": filepath.Join(exampleDir, "badssl.com-client.key.pem"), +// "cert-key-pass": "badssl.com", +// "mode": "http", +// } +// +// _ = LoadNewConfiguration(config) +// } diff --git a/tests/CertificateGenerator.go b/tests/CertificateGenerator.go new file mode 100644 index 0000000..9f266a6 --- /dev/null +++ b/tests/CertificateGenerator.go @@ -0,0 +1,76 @@ +package tests + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "io" + "math/big" + "net" + "time" +) + +func GenerateCertificate(clientAuth bool, certOut, privOut io.Writer) ([]byte, []byte, error) { + privDerBytes, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return nil, nil, err + } + + notBefore := time.Now() + notAfter := notBefore.Add(120) + + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + return nil, nil, err + } + + templateCert := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"Unit Test. DO NOT USE."}, + }, + NotBefore: notBefore, + NotAfter: notAfter, + + KeyUsage: 0, + ExtKeyUsage: []x509.ExtKeyUsage{}, + BasicConstraintsValid: true, + + IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1)}, + } + + templateCert.KeyUsage |= x509.KeyUsageDigitalSignature + if clientAuth { + templateCert.ExtKeyUsage = append(templateCert.ExtKeyUsage, x509.ExtKeyUsageClientAuth) + } else { + templateCert.ExtKeyUsage = append(templateCert.ExtKeyUsage, x509.ExtKeyUsageServerAuth) + templateCert.IsCA = true + templateCert.KeyUsage |= x509.KeyUsageCertSign + } + + certDerBytes, err := x509.CreateCertificate(rand.Reader, &templateCert, &templateCert, privDerBytes.Public(), privDerBytes) + if err != nil { + return nil, nil, err + } + + certPemBytes := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDerBytes}) + if _, err := certOut.Write(certPemBytes); err != nil { + return nil, nil, err + } + + privBytes, err := x509.MarshalPKCS8PrivateKey(privDerBytes) + if err != nil { + return nil, nil, err + } + + privPemBytes := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}) + if _, err := privOut.Write(privPemBytes); err != nil { + return nil, nil, err + } + + return certPemBytes, privPemBytes, nil +} diff --git a/tests/MainSupervisor.go b/tests/MainSupervisor.go new file mode 100644 index 0000000..fa63206 --- /dev/null +++ b/tests/MainSupervisor.go @@ -0,0 +1,100 @@ +package tests + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "os" + "os/exec" + "testing" + "time" + + "github.com/ajabep/unmtlsproxy/internal/configuration/configurationtest" +) + +type MainFunc func() + +type MainSupervisor struct { + envName string + testName string + main MainFunc + cmd *exec.Cmd +} + +// Will init a new supervisor to execute the main function without crashing the current program. +// It HAVE to be called in at the very start of the test! +func NewMainSupervisor(t *testing.T, main MainFunc) *MainSupervisor { + supervisor := &MainSupervisor{ + envName: "TESTING_EXEC_MAIN" + t.Name(), + testName: t.Name(), + main: main, + } + if encArgs, has := os.LookupEnv(supervisor.envName); has { + jsonArgs, err := base64.RawStdEncoding.DecodeString(encArgs) + if err != nil { + panic(err) + } + + var config map[string]string + err = json.Unmarshal(jsonArgs, &config) + if err != nil { + panic(err) + } + + configurationtest.SetupConfigurationEnv(config) + supervisor.main() + os.Exit(255) // Should never happen, but, just in case of + } + return supervisor +} + +func (m *MainSupervisor) Run(config map[string]string) (string, bool, error) { + var err error + if m.cmd != nil { + m.Close() + } + + addr, has := config["listen"] + if !has { + addr, _, _, err = configurationtest.NewListener() + if err != nil { + return "", false, err + } + config["listen"] = addr + } + + mainStarted := make(chan struct{}, 1) + mainStopped := make(chan struct{}, 2) + + go func(config map[string]string, mainStarted, mainStopped chan<- struct{}) { + jsonArgs, err := json.Marshal(config) + if err != nil { + panic(err) + } + rawArgs := base64.RawStdEncoding.EncodeToString(jsonArgs) + + m.cmd = exec.Command(os.Args[0], fmt.Sprintf("-test.run=%s", m.testName)) + m.cmd.Env = append(os.Environ(), fmt.Sprintf("%s=%s", m.envName, rawArgs)) + + close(mainStarted) + m.cmd.Run() + close(mainStopped) + }(config, mainStarted, mainStopped) + + <-mainStarted + mainHasReturned := false + select { + case <-mainStopped: + mainHasReturned = true + case <-time.After(500 * time.Millisecond): + } + + return addr, mainHasReturned, nil +} +func (m *MainSupervisor) Close() { + if m.cmd != nil && m.cmd.Process != nil { + m.cmd.Process.Kill() + m.cmd.Process.Wait() + } + m.cmd = nil +} diff --git a/tests/SyncedUint.go b/tests/SyncedUint.go new file mode 100644 index 0000000..8d48ef0 --- /dev/null +++ b/tests/SyncedUint.go @@ -0,0 +1,28 @@ +package tests + +import "sync" + +type SyncedUint struct { + m sync.Mutex + v uint +} + +func NewSyncedUintFrom(val uint) *SyncedUint { + s := SyncedUint{ + m: sync.Mutex{}, + v: val, + } + return &s +} + +func NewSyncedUint() *SyncedUint { + return NewSyncedUintFrom(0) +} + +func (s *SyncedUint) GetInc() uint { + s.m.Lock() + defer s.m.Unlock() + oldVal := s.v + s.v += 1 + return oldVal +} diff --git a/tests/TlsServerCounter.go b/tests/TlsServerCounter.go new file mode 100644 index 0000000..fb1e317 --- /dev/null +++ b/tests/TlsServerCounter.go @@ -0,0 +1,238 @@ +package tests + +import ( + "bufio" + "bytes" + "crypto/tls" + "fmt" + "io" + "net" + "net/netip" + "os" + "strings" + + "github.com/ajabep/unmtlsproxy/internal/configuration/configurationtest" +) + +type TlsServerCounter struct { + CertServerFilePath string + KeyServerFilePath string + CertClientFilePath string + KeyClientFilePath string + + ClientKeyPair tls.Certificate + ServerKeyPair tls.Certificate + + addr *netip.Addr + port uint16 + + tlsConfig *tls.Config + + httpMode bool +} + +const ( + httpRefLine string = "GET / HTTP/" + tcpRefLine = "R" +) + +/** + * Creates and a start a TLS server returning the "index" of the request as a body response. + * It also generate a random and temporary certs and keys. + */ +func NewStartedTlsServerCounter(httpMode bool) (*TlsServerCounter, error) { + srv := TlsServerCounter{ + httpMode: httpMode, + } + + certServerFile, err := os.CreateTemp("", "unmtlsproxy_unit_tests_cert_server_*") + if err != nil { + return nil, err + } + privServerFile, err := os.CreateTemp("", "unmtlsproxy_unit_tests_priv_server_*") + if err != nil { + return nil, err + } + certClientFile, err := os.CreateTemp("", "unmtlsproxy_unit_tests_cert_client_*") + if err != nil { + return nil, err + } + privClientFile, err := os.CreateTemp("", "unmtlsproxy_unit_tests_priv_client_*") + if err != nil { + return nil, err + } + + certServer, privServer, err := GenerateCertificate(false, certServerFile, privServerFile) + if err != nil { + return nil, err + } + certClient, privClient, err := GenerateCertificate(true, certClientFile, privClientFile) + if err != nil { + return nil, err + } + + if err := certServerFile.Close(); err != nil { + return nil, err + } + if err := privServerFile.Close(); err != nil { + return nil, err + } + if err := certClientFile.Close(); err != nil { + return nil, err + } + if err := privClientFile.Close(); err != nil { + return nil, err + } + srv.CertServerFilePath = certServerFile.Name() + srv.KeyServerFilePath = privServerFile.Name() + srv.CertClientFilePath = certClientFile.Name() + srv.KeyClientFilePath = privClientFile.Name() + + srv.ServerKeyPair, err = tls.X509KeyPair(certServer, privServer) + if err != nil { + return nil, err + } + srv.ClientKeyPair, err = tls.X509KeyPair(certClient, privClient) + if err != nil { + return nil, err + } + + _, srv.addr, srv.port, err = configurationtest.NewListener() + if err != nil { + return nil, err + } + + srv.tlsConfig = &tls.Config{ + Certificates: []tls.Certificate{srv.ServerKeyPair}, + ClientAuth: tls.RequireAnyClientCert, + InsecureSkipVerify: true, + ClientSessionCache: tls.NewLRUClientSessionCache(100), + } + listenerTcp, err := net.ListenTCP( + "tcp4", + srv.TcpAddr(), + ) + if err != nil { + return nil, err + } + listenerTls := tls.NewListener( + listenerTcp, + srv.tlsConfig, + ) + + go func(listener net.Listener) { + // Accept TCP conn + for { + conn, err := listener.Accept() + if err != nil { + panic(err) + } + + tlsConn, ok := conn.(*tls.Conn) + if ok { + if err := tlsConn.Handshake(); err != nil { + // If the handshake failed due to the client not speaking + // TLS, assume they're speaking plaintext HTTP and write a + // 400 response on the TLS conn's underlying net.Conn. + var reason string + re, ok := err.(tls.RecordHeaderError) + if ok && re.Conn != nil && bytes.Equal(re.RecordHeader[:5], []byte("GET /")) { + io.WriteString(re.Conn, "HTTP/1.0 400 Bad Request\r\n\r\nClient sent an HTTP request to an HTTPS server.\n") + re.Conn.Close() + reason = "client sent an HTTP request to an HTTPS server" + } else { + reason = err.Error() + } + io.WriteString(re.Conn, fmt.Sprintf("HTTP/1.0 400 Bad Request\r\n\r\nhttp: TLS handshake error from %s: %v\n", tlsConn.RemoteAddr(), reason)) + fmt.Printf("http: TLS handshake error from %s: %v", tlsConn.RemoteAddr(), reason) + return + } + } + + go func(conn net.Conn) { + defer conn.Close() + + // For each TCP conn + nbReq := NewSyncedUint() + + scanner := bufio.NewScanner(conn) + for scanner.Scan() { + line := scanner.Text() + line = strings.TrimSpace(line) + refLine := tcpRefLine + if srv.httpMode { + refLine = httpRefLine + } + if strings.Contains(line, refLine) { + reqNum := nbReq.GetInc() + var response string + if srv.httpMode { + response = srv.forgeHttpResponse(reqNum) + } else { + response = srv.forgeTcpResponse(reqNum % 10) + } + if _, err := conn.Write([]byte(response)); err != nil { + panic(err) + } + } + } + + if err := scanner.Err(); err != nil { + if e2, ok := err.(*net.OpError); ok && e2.Op == "read" { + return + } + panic(err) + } + }(conn) + } + }(listenerTls) + + return &srv, nil +} + +func (srv *TlsServerCounter) AddrPort() netip.AddrPort { + return netip.AddrPortFrom( + *srv.addr, + srv.port, + ) +} + +func (srv *TlsServerCounter) TcpAddr() *net.TCPAddr { + return net.TCPAddrFromAddrPort(srv.AddrPort()) +} + +func (srv *TlsServerCounter) AddrString() string { + return fmt.Sprintf("%s:%d", srv.addr, srv.port) +} + +func (srv *TlsServerCounter) Backend() string { + if srv.httpMode { + return fmt.Sprintf("https://%s:%d", srv.addr, srv.port) + } + return fmt.Sprintf("%s:%d", srv.addr, srv.port) +} + +func (srv *TlsServerCounter) Mode() string { + if srv.httpMode { + return "http" + } + return "tcp" +} + +func (srv *TlsServerCounter) forgeHttpResponse(id uint) string { + respBody := fmt.Sprintf("%d", id) + response := fmt.Sprintf( + `HTTP/1.1 200 OK +Content-Length: %d +Content-Type: text/plain; utf-8 + +%s`, + len(respBody), + respBody, + ) + return strings.Replace(response, "\n", "\r\n", -1) +} + +func (srv *TlsServerCounter) forgeTcpResponse(id uint) string { + return fmt.Sprintf("%d\n", id) +} diff --git a/unmtlsproxy_test.go b/unmtlsproxy_test.go new file mode 100644 index 0000000..fdae6e5 --- /dev/null +++ b/unmtlsproxy_test.go @@ -0,0 +1,570 @@ +package main + +import ( + "bytes" + "fmt" + "io" + "math/rand" + "net" + "net/http" + "path/filepath" + "strings" + "testing" + + "github.com/ajabep/unmtlsproxy/internal/configuration/configurationtest" + "github.com/ajabep/unmtlsproxy/tests" +) + +type HttpStatus int + +const ( + MainShouldFail HttpStatus = -1 +) + +type Constraint int + +const ( + Is Constraint = iota + Contains +) + +type FuncBytesTesting func([]byte, []byte) bool + +type TestCaseMainType struct { + name string + config map[string]string + expected struct { + status HttpStatus + bodyValue string + bodyConstraint Constraint + } +} +type TestCaseHttpDisableSocketReusingType struct { + name string + config map[string]string + noReuse bool +} +type TestCaseTcpSocketReusingDisabledType struct { + name string + config map[string]string + mainShouldFail bool +} + +func TestMainHttp(t *testing.T) { + mainSupervisor := tests.NewMainSupervisor(t, main) + defer mainSupervisor.Close() + exampleDir, err := configurationtest.GetExampleDir(0) + if err != nil { + panic(err) + } + + for _, testcase := range []TestCaseMainType{ + { + name: "Minimal things", + config: map[string]string{ + "backend": "https://client.badssl.com", + "cert": filepath.Join(exampleDir, "badssl.com-client.crt.pem"), + "cert-key": filepath.Join(exampleDir, "badssl.com-client_NOENCRYPTION.key.pem"), + "mode": "http", + }, + expected: struct { + status HttpStatus + bodyValue string + bodyConstraint Constraint + }{ + 200, + "body { background: green; }", + Contains, + }, + }, + // TODO: Open this notation in HTTP mode + //{ + // name: "Backend defined with its port", + // config: map[string]string{ + // "backend": "client.badssl.com:443", + // "cert": filepath.Join(exampleDir, "badssl.com-client.crt.pem"), + // "cert-key": filepath.Join(exampleDir, "badssl.com-client_NOENCRYPTION.key.pem"), + // "mode": "http", + // }, + // expected: struct { + // status HttpStatus + // bodyValue string + // bodyConstraint Constraint + // }{ + // 200, + // "body { background: green; }", + // Contains, + // }, + //}, + { + name: "Backend defined with its port AND its protocol", + config: map[string]string{ + "backend": "https://client.badssl.com:443", + "cert": filepath.Join(exampleDir, "badssl.com-client.crt.pem"), + "cert-key": filepath.Join(exampleDir, "badssl.com-client_NOENCRYPTION.key.pem"), + "mode": "http", + }, + expected: struct { + status HttpStatus + bodyValue string + bodyConstraint Constraint + }{ + 200, + "body { background: green; }", + Contains, + }, + }, + { + name: "Backend a wrong client cert", + config: map[string]string{ + "backend": "https://client-cert-missing.badssl.com", + "cert": filepath.Join(exampleDir, "badssl.com-client.crt.pem"), + "cert-key": filepath.Join(exampleDir, "badssl.com-client_NOENCRYPTION.key.pem"), + "mode": "http", + }, + expected: struct { + status HttpStatus + bodyValue string + bodyConstraint Constraint + }{ + 400, + "No required SSL certificate was sent", + Contains, + }, + }, + // TODO: Open this notation in HTTP mode + // { + // name: "Non valid backend", + // config: map[string]string{ + // "backend": "0.0.0.0:1111", + // "cert": filepath.Join(exampleDir, "badssl.com-client.crt.pem"), + // "cert-key": filepath.Join(exampleDir, "badssl.com-client_NOENCRYPTION.key.pem"), + // "mode": "http", + // }, + // expected: struct { + // status HttpStatus + // bodyValue string + // bodyConstraint Constraint + // }{ + // 503, + // "dial tcp 0.0.0.0:443: connectex: No connection could be made because the target machine actively refused it.", + // Is, + // }, + // }, + { + name: "Non existing backend", + config: map[string]string{ + "backend": "https://0.0.0.0", + "cert": filepath.Join(exampleDir, "badssl.com-client.crt.pem"), + "cert-key": filepath.Join(exampleDir, "badssl.com-client_NOENCRYPTION.key.pem"), + "mode": "http", + }, + expected: struct { + status HttpStatus + bodyValue string + bodyConstraint Constraint + }{ + 503, + "dial tcp 0.0.0.0:443: connectex: No connection could be made because the target machine actively refused it.", + Is, + }, + }, + { + name: "Wrong CA for validating the Server", + config: map[string]string{ + "backend": "https://client.badssl.com", + "cert": filepath.Join(exampleDir, "badssl.com-client.crt.pem"), + "cert-key": filepath.Join(exampleDir, "badssl.com-client_NOENCRYPTION.key.pem"), + "mode": "http", + "server-ca": filepath.Join(exampleDir, "badssl.com-client.crt.pem"), + }, + expected: struct { + status HttpStatus + bodyValue string + bodyConstraint Constraint + }{ + 503, + "tls: failed to verify certificate: x509: certificate signed by unknown authority", + Is, + }, + }, + { + name: "Wrong listen definition: Null port", + config: map[string]string{ + "backend": "https://client.badssl.com", + "cert": filepath.Join(exampleDir, "badssl.com-client.crt.pem"), + "cert-key": filepath.Join(exampleDir, "badssl.com-client_NOENCRYPTION.key.pem"), + "mode": "http", + "listen": "0.0.0.0:0", + }, + expected: struct { + status HttpStatus + bodyValue string + bodyConstraint Constraint + }{ + MainShouldFail, + "", + Is, + }, + }, + { + name: "Wrong listen definition: Negative port", + config: map[string]string{ + "backend": "https://client.badssl.com", + "cert": filepath.Join(exampleDir, "badssl.com-client.crt.pem"), + "cert-key": filepath.Join(exampleDir, "badssl.com-client_NOENCRYPTION.key.pem"), + "mode": "http", + "listen": "0.0.0.0:-1", + }, + expected: struct { + status HttpStatus + bodyValue string + bodyConstraint Constraint + }{ + MainShouldFail, + "", + Is, + }, + }, + { + name: "Wrong listen definition: only the port", + config: map[string]string{ + "backend": "https://client.badssl.com", + "cert": filepath.Join(exampleDir, "badssl.com-client.crt.pem"), + "cert-key": filepath.Join(exampleDir, "badssl.com-client_NOENCRYPTION.key.pem"), + "mode": "http", + "listen": "443", + }, + expected: struct { + status HttpStatus + bodyValue string + bodyConstraint Constraint + }{ + MainShouldFail, + "", + Is, + }, + }, + { + name: "Wrong Client Certificate Path", + config: map[string]string{ + "backend": "https://client.badssl.com", + "cert": fmt.Sprintf("%d", rand.Int()), + "cert-key": filepath.Join(exampleDir, "badssl.com-client_NOENCRYPTION.key.pem"), + "mode": "http", + }, + expected: struct { + status HttpStatus + bodyValue string + bodyConstraint Constraint + }{ + MainShouldFail, + "", + Is, + }, + }, + { + name: "Wrong Client Key Path", + config: map[string]string{ + "backend": "https://client.badssl.com", + "cert": filepath.Join(exampleDir, "badssl.com-client.crt.pem"), + "cert-key": fmt.Sprintf("%d", rand.Int()), + "mode": "http", + }, + expected: struct { + status HttpStatus + bodyValue string + bodyConstraint Constraint + }{ + MainShouldFail, + "", + Is, + }, + }, + { + name: "Good Client Key Password", + config: map[string]string{ + "backend": "https://client.badssl.com", + "cert": filepath.Join(exampleDir, "badssl.com-client.crt.pem"), + "cert-key": filepath.Join(exampleDir, "badssl.com-client.key.pem"), + "cert-key-pass": "badssl.com", + "mode": "http", + }, + expected: struct { + status HttpStatus + bodyValue string + bodyConstraint Constraint + }{ + MainShouldFail, + "", + Is, + }, + }, + { + name: "Wrong Client Key Password", + config: map[string]string{ + "backend": "https://client.badssl.com", + "cert": filepath.Join(exampleDir, "badssl.com-client.crt.pem"), + "cert-key": filepath.Join(exampleDir, "badssl.com-client.key.pem"), + "cert-key-pass": fmt.Sprintf("%d", rand.Int()), + "mode": "http", + }, + expected: struct { + status HttpStatus + bodyValue string + bodyConstraint Constraint + }{ + MainShouldFail, + "", + Is, + }, + }, + { + name: "Wrong Mode", + config: map[string]string{ + "backend": "https://client.badssl.com", + "cert": filepath.Join(exampleDir, "badssl.com-client.crt.pem"), + "cert-key": filepath.Join(exampleDir, "badssl.com-client_NOENCRYPTION.key.pem"), + "mode": fmt.Sprintf("%d", rand.Int()), + }, + expected: struct { + status HttpStatus + bodyValue string + bodyConstraint Constraint + }{ + MainShouldFail, + "", + Is, + }, + }, + } { + t.Logf("Running Test `%s`", testcase.name) + + addr, hasReturned, err := mainSupervisor.Run(testcase.config) + if err != nil { + panic(err) + } + + if testcase.expected.status != MainShouldFail { + if hasReturned { + t.Errorf("The main function has returned and should not returned.") + continue + } + } else { + if !hasReturned { + t.Errorf("The main function has not returned but should returned.") + } + continue + } + + addr = fmt.Sprintf("http://%s", addr) + resp, err := http.Get(addr) + if err != nil { + panic(err) + } + + if resp.StatusCode != int(testcase.expected.status) { + t.Errorf("Wrong Status Code! Had=%d, Expected=%d", resp.StatusCode, testcase.expected.status) + } + + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + panic(err) + } + body = bytes.TrimSpace(body) + + testValue := []byte(strings.TrimSpace(testcase.expected.bodyValue)) + var testFunc FuncBytesTesting + + if testcase.expected.bodyConstraint == Is { + testFunc = bytes.Equal + } else { + testFunc = bytes.Contains + } + + if !testFunc(body, testValue) { + t.Errorf("The body does not pass via the condition! Condition = `%d`; Condition Value = `%s`; Body = `%s`", testcase.expected.bodyConstraint, testcase.expected.bodyValue, body) + } + } +} + +//func TestUnsecureKeyLogPath(t *testing.T) { +// TODO +//} + +func TestHttpDisableSocketReusing(t *testing.T) { + mainSupervisor := tests.NewMainSupervisor(t, main) + defer mainSupervisor.Close() + + srv, err := tests.NewStartedTlsServerCounter(true) + if err != nil { + panic(err) + } + + for _, testcase := range []TestCaseHttpDisableSocketReusingType{ + { + name: "No options", + config: map[string]string{ + "backend": srv.Backend(), + "cert": srv.CertClientFilePath, + "cert-key": srv.KeyClientFilePath, + "mode": srv.Mode(), + }, + noReuse: false, + }, + // TODO + // { + // name: "Option to false", + // config: map[string]string{ + // "backend": srv.Backend(), + // "cert": srv.CertClientFilePath, + // "cert-key": srv.KeyClientFilePath, + // "mode": srv.Mode(), + // "disable-socket-reusing": "false", + // }, + // noReuse: false, + // }, + { + name: "Option to true", + config: map[string]string{ + "backend": srv.Backend(), + "cert": srv.CertClientFilePath, + "cert-key": srv.KeyClientFilePath, + "mode": srv.Mode(), + "disable-socket-reusing": "true", + }, + noReuse: true, + }, + } { + t.Logf("Running Test `%s`", testcase.name) + + addr, hasReturned, err := mainSupervisor.Run(testcase.config) + if err != nil { + panic(err) + } + addr = fmt.Sprintf("http://%s", addr) + + if hasReturned { + t.Errorf("The main function has returned and should not returned.") + continue + } + + for i := 0; i < 10; i++ { + resp, err := http.Get(addr) + if err != nil { + panic(err) + } + + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + panic(err) + } + + var msg string + var expected string + if testcase.noReuse { + expected = "0" + msg = "The proxy seems to re-use the socket despite the flag" + } else { + expected = fmt.Sprintf("%d", i) + msg = "The proxy seems to re-use the socket despite the flag" + } + if !bytes.Equal(body, []byte(expected)) { + t.Errorf("%s! Expected: %s; Body: %s", msg, expected, body) + } + } + } +} + +func TestTcpIsSocketReusingDisabled(t *testing.T) { + mainSupervisor := tests.NewMainSupervisor(t, main) + defer mainSupervisor.Close() + + srv, err := tests.NewStartedTlsServerCounter(false) + if err != nil { + panic(err) + } + + for _, testcase := range []TestCaseTcpSocketReusingDisabledType{ + { + name: "No options", + config: map[string]string{ + "backend": srv.Backend(), + "cert": srv.CertClientFilePath, + "cert-key": srv.KeyClientFilePath, + "mode": srv.Mode(), + }, + mainShouldFail: false, + }, + // TODO + // { + // name: "Option to false", + // config: map[string]string{ + // "backend": srv.Backend(), + // "cert": srv.CertClientFilePath, + // "cert-key": srv.KeyClientFilePath, + // "mode": srv.Mode(), + // "disable-socket-reusing": "false", + // }, + // mainShouldFail: true, + // }, + { + name: "Option to true", + config: map[string]string{ + "backend": srv.Backend(), + "cert": srv.CertClientFilePath, + "cert-key": srv.KeyClientFilePath, + "mode": srv.Mode(), + "disable-socket-reusing": "true", + }, + mainShouldFail: true, + }, + } { + t.Logf("Running Test `%s`", testcase.name) + + addr, hasReturned, err := mainSupervisor.Run(testcase.config) + if err != nil { + panic(err) + } + + if testcase.mainShouldFail { + if !hasReturned { + t.Errorf("The main function has not returned but should returned.") + } + continue + } + if hasReturned { + t.Errorf("The main function has returned and should not returned.") + continue + } + + for i := 0; i < 10; i++ { + conn, err := net.Dial("tcp", addr) + if err != nil { + panic(err) + } + defer conn.Close() + byteSent := make([]byte, 2) + for j := 0; j < 10; j++ { + conn.Write([]byte("R\n")) + + _, err := conn.Read(byteSent) + if err != nil { + panic(err) + } + + var msg string + expected := fmt.Sprintf("%d\n", j%10) + msg = "The proxy seems to re-use the socket despite the TCP mode" + if !bytes.Equal(byteSent, []byte(expected)) { + t.Errorf("%s! Expected: %s; Body: %s; Requested number: %d", msg, expected, byteSent, j) + } + } + } + } +} + +// Not allow HTTP (no SSL) backend in the HTTPS mode! It's completely silly! +// TODO Same for tcp! From 814bb1ba554f7fa39f395ae2b6083ee324f22f4f Mon Sep 17 00:00:00 2001 From: Ajabep Date: Thu, 29 Aug 2024 22:04:50 +0200 Subject: [PATCH 3/6] Add unit tests for TCP --- unmtlsproxy_test.go | 341 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 341 insertions(+) diff --git a/unmtlsproxy_test.go b/unmtlsproxy_test.go index fdae6e5..6dfbc78 100644 --- a/unmtlsproxy_test.go +++ b/unmtlsproxy_test.go @@ -1,6 +1,7 @@ package main import ( + "bufio" "bytes" "fmt" "io" @@ -389,6 +390,346 @@ func TestMainHttp(t *testing.T) { } } +func TestMainTcp(t *testing.T) { + mainSupervisor := tests.NewMainSupervisor(t, main) + defer mainSupervisor.Close() + exampleDir, err := configurationtest.GetExampleDir(0) + if err != nil { + panic(err) + } + + for _, testcase := range []TestCaseMainType{ + { + name: "Minimal things", + config: map[string]string{ + "backend": "client.badssl.com:433", + "cert": filepath.Join(exampleDir, "badssl.com-client.crt.pem"), + "cert-key": filepath.Join(exampleDir, "badssl.com-client_NOENCRYPTION.key.pem"), + "mode": "tcp", + }, + expected: struct { + status HttpStatus + bodyValue string + bodyConstraint Constraint + }{ + 200, + "body { background: green; }", + Contains, + }, + }, + { + name: "Backend defined with its protocol", + config: map[string]string{ + "backend": "https://client.badssl.com:443", + "cert": filepath.Join(exampleDir, "badssl.com-client.crt.pem"), + "cert-key": filepath.Join(exampleDir, "badssl.com-client_NOENCRYPTION.key.pem"), + "mode": "tcp", + }, + expected: struct { + status HttpStatus + bodyValue string + bodyConstraint Constraint + }{ + MainShouldFail, + "", + Contains, + }, + }, + { + name: "Backend defined with its port AND its protocol", + config: map[string]string{ + "backend": "https://client.badssl.com:443", + "cert": filepath.Join(exampleDir, "badssl.com-client.crt.pem"), + "cert-key": filepath.Join(exampleDir, "badssl.com-client_NOENCRYPTION.key.pem"), + "mode": "tcp", + }, + expected: struct { + status HttpStatus + bodyValue string + bodyConstraint Constraint + }{ + MainShouldFail, + "", + Contains, + }, + }, + { + name: "Backend a wrong client cert", + config: map[string]string{ + "backend": "client-cert-missing.badssl.com:443", + "cert": filepath.Join(exampleDir, "badssl.com-client.crt.pem"), + "cert-key": filepath.Join(exampleDir, "badssl.com-client_NOENCRYPTION.key.pem"), + "mode": "tcp", + }, + expected: struct { + status HttpStatus + bodyValue string + bodyConstraint Constraint + }{ + 400, + "No required SSL certificate was sent", + Contains, + }, + }, + { + name: "Non existing backend", + config: map[string]string{ + "backend": "0.0.0.0:443", + "cert": filepath.Join(exampleDir, "badssl.com-client.crt.pem"), + "cert-key": filepath.Join(exampleDir, "badssl.com-client_NOENCRYPTION.key.pem"), + "mode": "tcp", + }, + expected: struct { + status HttpStatus + bodyValue string + bodyConstraint Constraint + }{ + 503, + "dial tcp 0.0.0.0:443: connectex: No connection could be made because the target machine actively refused it.", + Is, + }, + }, + { + name: "Wrong CA for validating the Server", + config: map[string]string{ + "backend": "client.badssl.com:443", + "cert": filepath.Join(exampleDir, "badssl.com-client.crt.pem"), + "cert-key": filepath.Join(exampleDir, "badssl.com-client_NOENCRYPTION.key.pem"), + "mode": "tcp", + "server-ca": filepath.Join(exampleDir, "badssl.com-client.crt.pem"), + }, + expected: struct { + status HttpStatus + bodyValue string + bodyConstraint Constraint + }{ + 503, + "tls: failed to verify certificate: x509: certificate signed by unknown authority", + Is, + }, + }, + { + name: "Wrong listen definition: Null port", + config: map[string]string{ + "backend": "client.badssl.com:443", + "cert": filepath.Join(exampleDir, "badssl.com-client.crt.pem"), + "cert-key": filepath.Join(exampleDir, "badssl.com-client_NOENCRYPTION.key.pem"), + "mode": "tcp", + "listen": "0.0.0.0:0", + }, + expected: struct { + status HttpStatus + bodyValue string + bodyConstraint Constraint + }{ + MainShouldFail, + "", + Is, + }, + }, + { + name: "Wrong listen definition: Negative port", + config: map[string]string{ + "backend": "client.badssl.com:tcp", + "cert": filepath.Join(exampleDir, "badssl.com-client.crt.pem"), + "cert-key": filepath.Join(exampleDir, "badssl.com-client_NOENCRYPTION.key.pem"), + "mode": "tcp", + "listen": "0.0.0.0:-1", + }, + expected: struct { + status HttpStatus + bodyValue string + bodyConstraint Constraint + }{ + MainShouldFail, + "", + Is, + }, + }, + { + name: "Wrong listen definition: only the port", + config: map[string]string{ + "backend": "client.badssl.com:443", + "cert": filepath.Join(exampleDir, "badssl.com-client.crt.pem"), + "cert-key": filepath.Join(exampleDir, "badssl.com-client_NOENCRYPTION.key.pem"), + "mode": "tcp", + "listen": "443", + }, + expected: struct { + status HttpStatus + bodyValue string + bodyConstraint Constraint + }{ + MainShouldFail, + "", + Is, + }, + }, + { + name: "Wrong Client Certificate Path", + config: map[string]string{ + "backend": "client.badssl.com:443", + "cert": fmt.Sprintf("%d", rand.Int()), + "cert-key": filepath.Join(exampleDir, "badssl.com-client_NOENCRYPTION.key.pem"), + "mode": "tcp", + }, + expected: struct { + status HttpStatus + bodyValue string + bodyConstraint Constraint + }{ + MainShouldFail, + "", + Is, + }, + }, + { + name: "Wrong Client Key Path", + config: map[string]string{ + "backend": "client.badssl.com:443", + "cert": filepath.Join(exampleDir, "badssl.com-client.crt.pem"), + "cert-key": fmt.Sprintf("%d", rand.Int()), + "mode": "tcp", + }, + expected: struct { + status HttpStatus + bodyValue string + bodyConstraint Constraint + }{ + MainShouldFail, + "", + Is, + }, + }, + { + name: "Good Client Key Password", + config: map[string]string{ + "backend": "client.badssl.com:443", + "cert": filepath.Join(exampleDir, "badssl.com-client.crt.pem"), + "cert-key": filepath.Join(exampleDir, "badssl.com-client.key.pem"), + "cert-key-pass": "badssl.com", + "mode": "tcp", + }, + expected: struct { + status HttpStatus + bodyValue string + bodyConstraint Constraint + }{ + MainShouldFail, + "", + Is, + }, + }, + { + name: "Wrong Client Key Password", + config: map[string]string{ + "backend": "client.badssl.com:443", + "cert": filepath.Join(exampleDir, "badssl.com-client.crt.pem"), + "cert-key": filepath.Join(exampleDir, "badssl.com-client.key.pem"), + "cert-key-pass": fmt.Sprintf("%d", rand.Int()), + "mode": "tcp", + }, + expected: struct { + status HttpStatus + bodyValue string + bodyConstraint Constraint + }{ + MainShouldFail, + "", + Is, + }, + }, + { + name: "Wrong Mode", + config: map[string]string{ + "backend": "client.badssl.com:443", + "cert": filepath.Join(exampleDir, "badssl.com-client.crt.pem"), + "cert-key": filepath.Join(exampleDir, "badssl.com-client_NOENCRYPTION.key.pem"), + "mode": fmt.Sprintf("%d", rand.Int()), + }, + expected: struct { + status HttpStatus + bodyValue string + bodyConstraint Constraint + }{ + MainShouldFail, + "", + Is, + }, + }, + } { + t.Logf("Running Test `%s`", testcase.name) + + addr, hasReturned, err := mainSupervisor.Run(testcase.config) + if err != nil { + panic(err) + } + + if testcase.expected.status != MainShouldFail { + if hasReturned { + t.Errorf("The main function has returned and should not returned.") + continue + } + } else { + if !hasReturned { + t.Errorf("The main function has not returned but should returned.") + } + continue + } + + conn, err := net.Dial("tcp", addr) + if err != nil { + panic(err) + } + + req, err := http.NewRequest("GET", "/", nil) + if err != nil { + panic(err) + } + hostname := "example.com" + if v, has := testcase.config["backend"]; has { + hostname = v + } + req.Header.Add("Host", hostname) + req.Header.Add("Connection", "close") + err = req.Write(conn) + if err != nil { + panic(err) + } + + connReader := bufio.NewReader(conn) + + resp, err := http.ReadResponse(connReader, req) + if err != nil { + panic(err) + } + + if resp.StatusCode != int(testcase.expected.status) { + t.Errorf("Wrong Status Code! Had=%d, Expected=%d", resp.StatusCode, testcase.expected.status) + } + + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + panic(err) + } + body = bytes.TrimSpace(body) + + testValue := []byte(strings.TrimSpace(testcase.expected.bodyValue)) + var testFunc FuncBytesTesting + + if testcase.expected.bodyConstraint == Is { + testFunc = bytes.Equal + } else { + testFunc = bytes.Contains + } + + if !testFunc(body, testValue) { + t.Errorf("The body does not pass via the condition! Condition = `%d`; Condition Value = `%s`; Body = `%s`", testcase.expected.bodyConstraint, testcase.expected.bodyValue, body) + } + } +} + //func TestUnsecureKeyLogPath(t *testing.T) { // TODO //} From 0cafd063f685d10e0f2df0d5fd8671df49a9d670 Mon Sep 17 00:00:00 2001 From: Ajabep Date: Fri, 30 Aug 2024 15:14:06 +0200 Subject: [PATCH 4/6] Typo --- internal/tcpproxy/proxy.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/tcpproxy/proxy.go b/internal/tcpproxy/proxy.go index 223116f..061793b 100644 --- a/internal/tcpproxy/proxy.go +++ b/internal/tcpproxy/proxy.go @@ -88,12 +88,12 @@ func (p *proxy) copy(ctx context.Context, cancel context.CancelFunc, from, to ne default: for { - n, err = to.Read(buffer) + n, err = from.Read(buffer) if err != nil { return } - _, err = from.Write(buffer[:n]) + _, err = to.Write(buffer[:n]) if err != nil { return } From d5ae72998f01c897441a51d187983c7620d43767 Mon Sep 17 00:00:00 2001 From: Ajabep Date: Sat, 7 Sep 2024 20:01:27 +0200 Subject: [PATCH 5/6] Fix tests and add some other --- internal/httpproxy/proxy.go | 2 +- internal/tcpproxy/proxy.go | 28 +++-- tests/TlsServerCounter.go | 2 +- unmtlsproxy_test.go | 212 +++++++++++++++++++++++------------- 4 files changed, 158 insertions(+), 86 deletions(-) diff --git a/internal/httpproxy/proxy.go b/internal/httpproxy/proxy.go index 840f920..28394e1 100644 --- a/internal/httpproxy/proxy.go +++ b/internal/httpproxy/proxy.go @@ -105,7 +105,7 @@ func makeHandleHTTP(dest string, tlsConfig *tls.Config, reuseSockets bool) func( return } - defer resp.Body.Close() // nolint: errcheck + defer resp.Body.Close() for k, vv := range resp.Header { for _, v := range vv { w.Header().Add(k, v) diff --git a/internal/tcpproxy/proxy.go b/internal/tcpproxy/proxy.go index 061793b..476464d 100644 --- a/internal/tcpproxy/proxy.go +++ b/internal/tcpproxy/proxy.go @@ -16,6 +16,7 @@ import ( "crypto/tls" "log" "net" + "net/url" "os" "os/signal" @@ -38,12 +39,11 @@ func newProxy(from, to string, tlsConfig *tls.Config) *proxy { // Start the proxy. Is blocking! func (p *proxy) start(ctx context.Context) error { - listener, err := net.Listen("tcp", p.from) if err != nil { return err } - defer listener.Close() // nolint + defer listener.Close() for { select { @@ -60,13 +60,14 @@ func (p *proxy) start(ctx context.Context) error { } func (p *proxy) handle(ctx context.Context, connection net.Conn) { + defer connection.Close() - defer connection.Close() // nolint remote, err := tls.Dial("tcp", p.to, p.tlsConfig) if err != nil { + connection.Write([]byte(err.Error())) return } - defer remote.Close() // nolint + defer remote.Close() subctx, cancel := context.WithCancel(ctx) go p.copy(subctx, cancel, remote, connection) @@ -76,7 +77,6 @@ func (p *proxy) handle(ctx context.Context, connection net.Conn) { } func (p *proxy) copy(ctx context.Context, cancel context.CancelFunc, from, to net.Conn) { - defer cancel() var n int @@ -86,7 +86,6 @@ func (p *proxy) copy(ctx context.Context, cancel context.CancelFunc, from, to ne select { default: - for { n, err = from.Read(buffer) if err != nil { @@ -106,10 +105,25 @@ func (p *proxy) copy(ctx context.Context, cancel context.CancelFunc, from, to ne // Start starts the proxy func Start(cfg *configuration.Configuration, tlsConfig *tls.Config) { - ctx, cancel := context.WithCancel(context.Background()) defer cancel() + parsed, err := url.Parse("tcp://" + cfg.Backend) + if err != nil { + panic(err) + } + if parsed.Host != cfg.Backend { + panic("The backend definition seems to be invalid!") + } + + parsed, err = url.Parse("tcp://" + cfg.ListenAddress) + if err != nil { + panic(err) + } + if parsed.Host != cfg.ListenAddress { + panic("The backend definition seems to be invalid!") + } + go func() { if err := newProxy(cfg.ListenAddress, cfg.Backend, tlsConfig).start(ctx); err != nil { log.Fatalln("Unable to start proxy:", err) diff --git a/tests/TlsServerCounter.go b/tests/TlsServerCounter.go index fb1e317..f2bdb70 100644 --- a/tests/TlsServerCounter.go +++ b/tests/TlsServerCounter.go @@ -224,7 +224,7 @@ func (srv *TlsServerCounter) forgeHttpResponse(id uint) string { response := fmt.Sprintf( `HTTP/1.1 200 OK Content-Length: %d -Content-Type: text/plain; utf-8 +Content-Type: text/plain; charset=utf-8 %s`, len(respBody), diff --git a/unmtlsproxy_test.go b/unmtlsproxy_test.go index 6dfbc78..355a984 100644 --- a/unmtlsproxy_test.go +++ b/unmtlsproxy_test.go @@ -3,14 +3,17 @@ package main import ( "bufio" "bytes" + "errors" "fmt" "io" "math/rand" "net" "net/http" + "os" "path/filepath" "strings" "testing" + "time" "github.com/ajabep/unmtlsproxy/internal/configuration/configurationtest" "github.com/ajabep/unmtlsproxy/tests" @@ -19,7 +22,8 @@ import ( type HttpStatus int const ( - MainShouldFail HttpStatus = -1 + MainShouldFail HttpStatus = -1 + NotHttpExpected HttpStatus = -2 ) type Constraint int @@ -56,7 +60,8 @@ func TestMainHttp(t *testing.T) { defer mainSupervisor.Close() exampleDir, err := configurationtest.GetExampleDir(0) if err != nil { - panic(err) + t.Errorf("Unexpected Error: %#v", err) + return } for _, testcase := range []TestCaseMainType{ @@ -166,8 +171,8 @@ func TestMainHttp(t *testing.T) { bodyConstraint Constraint }{ 503, - "dial tcp 0.0.0.0:443: connectex: No connection could be made because the target machine actively refused it.", - Is, + "dial tcp 0.0.0.0:443: connect", + Contains, }, }, { @@ -343,7 +348,8 @@ func TestMainHttp(t *testing.T) { addr, hasReturned, err := mainSupervisor.Run(testcase.config) if err != nil { - panic(err) + t.Errorf("Unexpected Error: %#v", err) + continue } if testcase.expected.status != MainShouldFail { @@ -361,7 +367,8 @@ func TestMainHttp(t *testing.T) { addr = fmt.Sprintf("http://%s", addr) resp, err := http.Get(addr) if err != nil { - panic(err) + t.Errorf("Unexpected Error: %#v", err) + continue } if resp.StatusCode != int(testcase.expected.status) { @@ -371,7 +378,8 @@ func TestMainHttp(t *testing.T) { defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { - panic(err) + t.Errorf("Unexpected Error: %#v", err) + continue } body = bytes.TrimSpace(body) @@ -395,14 +403,15 @@ func TestMainTcp(t *testing.T) { defer mainSupervisor.Close() exampleDir, err := configurationtest.GetExampleDir(0) if err != nil { - panic(err) + t.Errorf("Unexpected Error: %#v", err) + return } for _, testcase := range []TestCaseMainType{ { name: "Minimal things", config: map[string]string{ - "backend": "client.badssl.com:433", + "backend": "client.badssl.com:443", "cert": filepath.Join(exampleDir, "badssl.com-client.crt.pem"), "cert-key": filepath.Join(exampleDir, "badssl.com-client_NOENCRYPTION.key.pem"), "mode": "tcp", @@ -420,7 +429,7 @@ func TestMainTcp(t *testing.T) { { name: "Backend defined with its protocol", config: map[string]string{ - "backend": "https://client.badssl.com:443", + "backend": "https://client.badssl.com", "cert": filepath.Join(exampleDir, "badssl.com-client.crt.pem"), "cert-key": filepath.Join(exampleDir, "badssl.com-client_NOENCRYPTION.key.pem"), "mode": "tcp", @@ -484,9 +493,9 @@ func TestMainTcp(t *testing.T) { bodyValue string bodyConstraint Constraint }{ - 503, - "dial tcp 0.0.0.0:443: connectex: No connection could be made because the target machine actively refused it.", - Is, + NotHttpExpected, + "dial tcp 0.0.0.0:443: connect", + Contains, }, }, { @@ -503,7 +512,7 @@ func TestMainTcp(t *testing.T) { bodyValue string bodyConstraint Constraint }{ - 503, + NotHttpExpected, "tls: failed to verify certificate: x509: certificate signed by unknown authority", Is, }, @@ -658,75 +667,117 @@ func TestMainTcp(t *testing.T) { }, }, } { - t.Logf("Running Test `%s`", testcase.name) + testingFnc := func(testcase TestCaseMainType) { + t.Logf("Running Test `%s`", testcase.name) - addr, hasReturned, err := mainSupervisor.Run(testcase.config) - if err != nil { - panic(err) - } + addr, hasReturned, err := mainSupervisor.Run(testcase.config) + if err != nil { + t.Errorf("Unexpected Error: %#v", err) + return + } + defer mainSupervisor.Close() - if testcase.expected.status != MainShouldFail { - if hasReturned { - t.Errorf("The main function has returned and should not returned.") - continue + if testcase.expected.status != MainShouldFail { + if hasReturned { + t.Errorf("The main function has returned and should not returned.") + return + } + } else { + if !hasReturned { + t.Errorf("The main function has not returned but should returned.") + } + return } - } else { - if !hasReturned { - t.Errorf("The main function has not returned but should returned.") + + conn, err := net.Dial("tcp", addr) + if err != nil { + t.Errorf("Unexpected Error: %#v", err) + return } - continue - } + defer conn.Close() - conn, err := net.Dial("tcp", addr) - if err != nil { - panic(err) - } + connReader := bufio.NewReader(conn) - req, err := http.NewRequest("GET", "/", nil) - if err != nil { - panic(err) - } - hostname := "example.com" - if v, has := testcase.config["backend"]; has { - hostname = v - } - req.Header.Add("Host", hostname) - req.Header.Add("Connection", "close") - err = req.Write(conn) - if err != nil { - panic(err) - } + conn.SetReadDeadline(time.Now().Add(500 * time.Millisecond)) + body, err := io.ReadAll(connReader) + if err != nil { + if !errors.Is(err, os.ErrDeadlineExceeded) { + t.Errorf("Unexpected Error: %#v", err) + return + } + } - connReader := bufio.NewReader(conn) + statusCode := int(NotHttpExpected) - resp, err := http.ReadResponse(connReader, req) - if err != nil { - panic(err) - } + if len(body) == 0 { + // No error have been raised when opening the socket - if resp.StatusCode != int(testcase.expected.status) { - t.Errorf("Wrong Status Code! Had=%d, Expected=%d", resp.StatusCode, testcase.expected.status) - } + hostname := "example.com" + if v, has := testcase.config["backend"]; has { + hostname = v + } - defer resp.Body.Close() - body, err := io.ReadAll(resp.Body) - if err != nil { - panic(err) - } - body = bytes.TrimSpace(body) + req, err := http.NewRequest("GET", fmt.Sprintf("https://%s/", hostname), nil) + if err != nil { + t.Errorf("Unexpected Error: %#v", err) + } + hostname, _ = strings.CutSuffix(hostname, ":443") + req.Header.Add("Host", hostname) + req.Header.Add("Connection", "close") + req.Header.Add("Content-Length", "0") + err = req.Write(conn) + if err != nil { + t.Errorf("Unexpected Error: %#v", err) + return + } - testValue := []byte(strings.TrimSpace(testcase.expected.bodyValue)) - var testFunc FuncBytesTesting + conn.SetReadDeadline(time.Time{}) + body, err = io.ReadAll(connReader) + if err != nil { + if !errors.Is(err, os.ErrDeadlineExceeded) { + t.Errorf("Unexpected Error: %#v", err) + return + } + } - if testcase.expected.bodyConstraint == Is { - testFunc = bytes.Equal - } else { - testFunc = bytes.Contains - } + // Now, we have the conn content. + + // Let's see if it's a HTTP Response! + bodyReader := bufio.NewReader(bytes.NewReader(body)) + + resp, err := http.ReadResponse(bodyReader, req) + if err == nil { + // It's an HTTP response, so, changing the variables we will compare + statusCode = resp.StatusCode + defer resp.Body.Close() + body, err = io.ReadAll(resp.Body) + if err != nil { + t.Errorf("Unexpected Error: %#v", err) + return + } + } + } - if !testFunc(body, testValue) { - t.Errorf("The body does not pass via the condition! Condition = `%d`; Condition Value = `%s`; Body = `%s`", testcase.expected.bodyConstraint, testcase.expected.bodyValue, body) + if statusCode != int(testcase.expected.status) { + t.Errorf("Wrong Status Code! Had=%d, Expected=%d", statusCode, testcase.expected.status) + } + + body = bytes.TrimSpace(body) + + testValue := []byte(strings.TrimSpace(testcase.expected.bodyValue)) + var testFunc FuncBytesTesting + + if testcase.expected.bodyConstraint == Is { + testFunc = bytes.Equal + } else { + testFunc = bytes.Contains + } + + if !testFunc(body, testValue) { + t.Errorf("The body does not pass via the condition! Condition = `%d`; Condition Value = `%s`; Body = `%s`", testcase.expected.bodyConstraint, testcase.expected.bodyValue, body) + } } + testingFnc(testcase) } } @@ -740,7 +791,8 @@ func TestHttpDisableSocketReusing(t *testing.T) { srv, err := tests.NewStartedTlsServerCounter(true) if err != nil { - panic(err) + t.Errorf("Unexpected Error: %#v", err) + return } for _, testcase := range []TestCaseHttpDisableSocketReusingType{ @@ -782,7 +834,7 @@ func TestHttpDisableSocketReusing(t *testing.T) { addr, hasReturned, err := mainSupervisor.Run(testcase.config) if err != nil { - panic(err) + t.Errorf("Unexpected Error: %#v", err) } addr = fmt.Sprintf("http://%s", addr) @@ -794,13 +846,15 @@ func TestHttpDisableSocketReusing(t *testing.T) { for i := 0; i < 10; i++ { resp, err := http.Get(addr) if err != nil { - panic(err) + t.Errorf("Unexpected Error: %#v", err) + continue } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { - panic(err) + t.Errorf("Unexpected Error: %#v", err) + continue } var msg string @@ -825,7 +879,8 @@ func TestTcpIsSocketReusingDisabled(t *testing.T) { srv, err := tests.NewStartedTlsServerCounter(false) if err != nil { - panic(err) + t.Errorf("Unexpected Error: %#v", err) + return } for _, testcase := range []TestCaseTcpSocketReusingDisabledType{ @@ -867,7 +922,8 @@ func TestTcpIsSocketReusingDisabled(t *testing.T) { addr, hasReturned, err := mainSupervisor.Run(testcase.config) if err != nil { - panic(err) + t.Errorf("Unexpected Error: %#v", err) + continue } if testcase.mainShouldFail { @@ -884,7 +940,8 @@ func TestTcpIsSocketReusingDisabled(t *testing.T) { for i := 0; i < 10; i++ { conn, err := net.Dial("tcp", addr) if err != nil { - panic(err) + t.Errorf("Unexpected Error: %#v", err) + continue } defer conn.Close() byteSent := make([]byte, 2) @@ -893,7 +950,8 @@ func TestTcpIsSocketReusingDisabled(t *testing.T) { _, err := conn.Read(byteSent) if err != nil { - panic(err) + t.Errorf("Unexpected Error: %#v", err) + continue } var msg string From 56250439882f02548254d64e99b8bdb5954fac97 Mon Sep 17 00:00:00 2001 From: Ajabep Date: Sat, 7 Sep 2024 22:29:17 +0200 Subject: [PATCH 6/6] Fix some tests --- tests/TlsServerCounter.go | 10 ++++ unmtlsproxy_test.go | 99 ++++++++++++++++++++------------------- 2 files changed, 61 insertions(+), 48 deletions(-) diff --git a/tests/TlsServerCounter.go b/tests/TlsServerCounter.go index f2bdb70..05de3f9 100644 --- a/tests/TlsServerCounter.go +++ b/tests/TlsServerCounter.go @@ -10,6 +10,7 @@ import ( "net/netip" "os" "strings" + "time" "github.com/ajabep/unmtlsproxy/internal/configuration/configurationtest" ) @@ -61,6 +62,15 @@ func NewStartedTlsServerCounter(httpMode bool) (*TlsServerCounter, error) { if err != nil { return nil, err } + go func() { + defer os.Remove(certServerFile.Name()) + defer os.Remove(privServerFile.Name()) + defer os.Remove(certClientFile.Name()) + defer os.Remove(privClientFile.Name()) + for { + time.Sleep(999 * time.Hour) + } + }() certServer, privServer, err := GenerateCertificate(false, certServerFile, privServerFile) if err != nil { diff --git a/unmtlsproxy_test.go b/unmtlsproxy_test.go index 355a984..83439c7 100644 --- a/unmtlsproxy_test.go +++ b/unmtlsproxy_test.go @@ -24,6 +24,7 @@ type HttpStatus int const ( MainShouldFail HttpStatus = -1 NotHttpExpected HttpStatus = -2 + NoRequestSent HttpStatus = -3 ) type Constraint int @@ -287,25 +288,26 @@ func TestMainHttp(t *testing.T) { Is, }, }, - { - name: "Good Client Key Password", - config: map[string]string{ - "backend": "https://client.badssl.com", - "cert": filepath.Join(exampleDir, "badssl.com-client.crt.pem"), - "cert-key": filepath.Join(exampleDir, "badssl.com-client.key.pem"), - "cert-key-pass": "badssl.com", - "mode": "http", - }, - expected: struct { - status HttpStatus - bodyValue string - bodyConstraint Constraint - }{ - MainShouldFail, - "", - Is, - }, - }, + // Issue #31 + // { + // name: "Correct Client Key Password", + // config: map[string]string{ + // "backend": "https://client.badssl.com", + // "cert": filepath.Join(exampleDir, "badssl.com-client.crt.pem"), + // "cert-key": filepath.Join(exampleDir, "badssl.com-client.key.pem"), + // "cert-key-pass": "badssl.com", + // "mode": "http", + // }, + // expected: struct { + // status HttpStatus + // bodyValue string + // bodyConstraint Constraint + // }{ + // 200, + // "body { background: green; }", + // Contains, + // }, + // }, { name: "Wrong Client Key Password", config: map[string]string{ @@ -493,7 +495,7 @@ func TestMainTcp(t *testing.T) { bodyValue string bodyConstraint Constraint }{ - NotHttpExpected, + NoRequestSent, "dial tcp 0.0.0.0:443: connect", Contains, }, @@ -512,7 +514,7 @@ func TestMainTcp(t *testing.T) { bodyValue string bodyConstraint Constraint }{ - NotHttpExpected, + NoRequestSent, "tls: failed to verify certificate: x509: certificate signed by unknown authority", Is, }, @@ -610,25 +612,26 @@ func TestMainTcp(t *testing.T) { Is, }, }, - { - name: "Good Client Key Password", - config: map[string]string{ - "backend": "client.badssl.com:443", - "cert": filepath.Join(exampleDir, "badssl.com-client.crt.pem"), - "cert-key": filepath.Join(exampleDir, "badssl.com-client.key.pem"), - "cert-key-pass": "badssl.com", - "mode": "tcp", - }, - expected: struct { - status HttpStatus - bodyValue string - bodyConstraint Constraint - }{ - MainShouldFail, - "", - Is, - }, - }, + // Issue #31 + // { + // name: "Correct Client Key Password", + // config: map[string]string{ + // "backend": "client.badssl.com:443", + // "cert": filepath.Join(exampleDir, "badssl.com-client.crt.pem"), + // "cert-key": filepath.Join(exampleDir, "badssl.com-client.key.pem"), + // "cert-key-pass": "badssl.com", + // "mode": "tcp", + // }, + // expected: struct { + // status HttpStatus + // bodyValue string + // bodyConstraint Constraint + // }{ + // 200, + // "body { background: green; }", + // Contains, + // }, + // }, { name: "Wrong Client Key Password", config: map[string]string{ @@ -707,7 +710,7 @@ func TestMainTcp(t *testing.T) { } } - statusCode := int(NotHttpExpected) + statusCode := int(NoRequestSent) if len(body) == 0 { // No error have been raised when opening the socket @@ -739,6 +742,7 @@ func TestMainTcp(t *testing.T) { return } } + statusCode = int(NotHttpExpected) // Now, we have the conn content. @@ -781,9 +785,9 @@ func TestMainTcp(t *testing.T) { } } -//func TestUnsecureKeyLogPath(t *testing.T) { -// TODO -//} +// TODO Find a way to test that! +// func TestUnsecureKeyLogPath(t *testing.T) { +// } func TestHttpDisableSocketReusing(t *testing.T) { mainSupervisor := tests.NewMainSupervisor(t, main) @@ -797,7 +801,7 @@ func TestHttpDisableSocketReusing(t *testing.T) { for _, testcase := range []TestCaseHttpDisableSocketReusingType{ { - name: "No options", + name: "No option", config: map[string]string{ "backend": srv.Backend(), "cert": srv.CertClientFilePath, @@ -885,7 +889,7 @@ func TestTcpIsSocketReusingDisabled(t *testing.T) { for _, testcase := range []TestCaseTcpSocketReusingDisabledType{ { - name: "No options", + name: "No option", config: map[string]string{ "backend": srv.Backend(), "cert": srv.CertClientFilePath, @@ -965,5 +969,4 @@ func TestTcpIsSocketReusingDisabled(t *testing.T) { } } -// Not allow HTTP (no SSL) backend in the HTTPS mode! It's completely silly! -// TODO Same for tcp! +// TODO Not allow HTTP (no SSL) backend in the HTTPS mode! It's completely silly!