Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Unit tests and various fixes #36

Merged
merged 6 commits into from
Sep 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/go.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ jobs:
with:
go-version: '1.22'

- name: Vet
run: go vet ./...

- name: Build
run: go build -v ./...

Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ unmtlsproxy_win_386.exe
unmtlsproxy
unmtlsproxy.exe
local_build.ps1
__debug_bin*
16 changes: 16 additions & 0 deletions internal/configuration/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ import (
"crypto/tls"
"crypto/x509"
"fmt"
"net/url"
"os"
"strconv"

"go.aporeto.io/addedeffect/lombric"
"go.aporeto.io/tg/tglib"
Expand Down Expand Up @@ -55,6 +57,20 @@ func NewConfiguration() *Configuration {
c := &Configuration{}
lombric.Initialize(c)

listenUrl, err := url.Parse("http://" + c.ListenAddress)
if err != nil {
panic(err)
}
if port := listenUrl.Port(); port == "" {
panic("Invalid Listen format. Use `hostname:port`.")
} else if portInt, err := strconv.Atoi(port); err != nil {
panic(err)
} else if portInt <= 0 {
panic("Invalid Listening Port. Too low.")
} else if portInt > 65535 {
panic("Invalid Listening Port. Too High. We use TCP.")
}

if c.Mode == "tcp" {
if c.DisableSocketReusing {
panic("Option 'disable-socket-reusing' is forbidden in TCP mode. Socket reusing cannot being enabled, option is useless")
Expand Down
133 changes: 133 additions & 0 deletions internal/configuration/configurationtest/configuration.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
package configurationtest

import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"math/big"
"net"
"net/netip"
"os"
"path/filepath"
"strings"
"time"

"github.com/ajabep/unmtlsproxy/internal/configuration"
)

func GenerateCertificate(certificatePath string, keyPath string) error {
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return err
}

pvBytes, err := x509.MarshalECPrivateKey(key)
if err != nil {
return err
}

// Generate a pem block with the private key
keyPem := pem.EncodeToMemory(&pem.Block{
Type: "EC PRIVATE KEY",
Bytes: pvBytes,
})

tml := x509.Certificate{
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(5, 0, 0),
SerialNumber: big.NewInt(123123),
Subject: pkix.Name{
CommonName: "Hugging Department",
Organization: []string{"BlaHaj Corp."},
},
PublicKeyAlgorithm: x509.ECDSA,
SignatureAlgorithm: x509.ECDSAWithSHA256,
}

cert, err := x509.CreateCertificate(rand.Reader, &tml, &tml, &key.PublicKey, key)
if err != nil {
return err
}

// Generate a pem block with the certificate
certPem := pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: cert,
})

keyOut, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil {
return err
}
certificateOut, err := os.OpenFile(certificatePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil {
return err
}

if _, err := keyOut.Write(keyPem); err != nil {
return err
}
if _, err := certificateOut.Write(certPem); err != nil {
return err
}
return nil
}

func SetupConfigurationEnv(args map[string]string) {
for k, v := range args {
k = strings.TrimPrefix(k, "--")
k = strings.ToUpper(k)
k = "UNMTLSPROXY_" + k
k = strings.ReplaceAll(k, " ", "_")
k = strings.ReplaceAll(k, "-", "_")

os.Setenv(k, v)
}
}

func LoadNewConfiguration(args map[string]string) *configuration.Configuration {
SetupConfigurationEnv(args)
return configuration.NewConfiguration()
}

func GetExampleDir(level int) (string, error) {
currentDir, err := os.Getwd() // os.Executable()
if err != nil {
return "", err
}
currentDir, err = filepath.Abs(currentDir)
if err != nil {
return "", err
}
for range level {
currentDir = filepath.Dir(currentDir)
}
return filepath.Join(currentDir, "example"), nil
}

func NewListener() (string, *netip.Addr, uint16, error) {
var minPort, maxPort uint16 = 5000, 65535
addr, err := netip.ParseAddr("127.0.0.1")
if err != nil {
return "", nil, 0, err
}
for {
x, err := rand.Int(rand.Reader, big.NewInt(int64(maxPort-minPort)))
if err != nil {
return "", nil, 0, err
}
port := uint16(x.Uint64())
port += minPort

l, err := net.Listen("tcp", fmt.Sprintf("%s:%d", addr.String(), port))
if err != nil {
continue
}
l.Close()
return fmt.Sprintf("%s:%d", addr.String(), port), &addr, uint16(port), nil
}
}
39 changes: 39 additions & 0 deletions internal/configuration/configurationtest/configuration_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package configurationtest

import (
"path/filepath"
"testing"
)

func TestNewConfigurationValidMinimalist(t *testing.T) {
exampleDir, err := GetExampleDir(3)
if err != nil {
panic(err)
}

config := map[string]string{
"backend": "https://client.badssl.com",
"cert": filepath.Join(exampleDir, "badssl.com-client.crt.pem"),
"cert-key": filepath.Join(exampleDir, "badssl.com-client_NOENCRYPTION.key.pem"),
"mode": "http",
}

_ = LoadNewConfiguration(config)
}

// func TestNewConfigurationValidClientCertificatePassword(t *testing.T) {
// exampleDir, err := GetExampleDir(3)
// if err != nil {
// panic(err)
// }
//
// config := map[string]string{
// "backend": "https://client.badssl.com",
// "cert": filepath.Join(exampleDir, "badssl.com-client.crt.pem"),
// "cert-key": filepath.Join(exampleDir, "badssl.com-client.key.pem"),
// "cert-key-pass": "badssl.com",
// "mode": "http",
// }
//
// _ = LoadNewConfiguration(config)
// }
4 changes: 2 additions & 2 deletions internal/httpproxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func makeHandleHTTP(dest string, tlsConfig *tls.Config, reuseSockets bool) func(
case "80":
u.Scheme = "http"
case "443":
u.Scheme = "http"
u.Scheme = "https"
case "":
switch u.Scheme {
default:
Expand Down Expand Up @@ -105,7 +105,7 @@ func makeHandleHTTP(dest string, tlsConfig *tls.Config, reuseSockets bool) func(
return
}

defer resp.Body.Close() // nolint: errcheck
defer resp.Body.Close()
for k, vv := range resp.Header {
for _, v := range vv {
w.Header().Add(k, v)
Expand Down
34 changes: 24 additions & 10 deletions internal/tcpproxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"crypto/tls"
"log"
"net"
"net/url"
"os"
"os/signal"

Expand All @@ -36,14 +37,13 @@ func newProxy(from, to string, tlsConfig *tls.Config) *proxy {
}
}

// Start the proxy. Is blocking!
func (p *proxy) start(ctx context.Context) error {
/// Start the proxy. Is blocking!

listener, err := net.Listen("tcp", p.from)
if err != nil {
return err
}
defer listener.Close() // nolint
defer listener.Close()

for {
select {
Expand All @@ -60,13 +60,14 @@ func (p *proxy) start(ctx context.Context) error {
}

func (p *proxy) handle(ctx context.Context, connection net.Conn) {
defer connection.Close()

defer connection.Close() // nolint
remote, err := tls.Dial("tcp", p.to, p.tlsConfig)
if err != nil {
connection.Write([]byte(err.Error()))
return
}
defer remote.Close() // nolint
defer remote.Close()

subctx, cancel := context.WithCancel(ctx)
go p.copy(subctx, cancel, remote, connection)
Expand All @@ -76,7 +77,6 @@ func (p *proxy) handle(ctx context.Context, connection net.Conn) {
}

func (p *proxy) copy(ctx context.Context, cancel context.CancelFunc, from, to net.Conn) {

defer cancel()

var n int
Expand All @@ -86,14 +86,13 @@ func (p *proxy) copy(ctx context.Context, cancel context.CancelFunc, from, to ne
select {

default:

for {
n, err = to.Read(buffer)
n, err = from.Read(buffer)
if err != nil {
return
}

_, err = from.Write(buffer[:n])
_, err = to.Write(buffer[:n])
if err != nil {
return
}
Expand All @@ -106,10 +105,25 @@ func (p *proxy) copy(ctx context.Context, cancel context.CancelFunc, from, to ne

// Start starts the proxy
func Start(cfg *configuration.Configuration, tlsConfig *tls.Config) {

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

parsed, err := url.Parse("tcp://" + cfg.Backend)
if err != nil {
panic(err)
}
if parsed.Host != cfg.Backend {
panic("The backend definition seems to be invalid!")
}

parsed, err = url.Parse("tcp://" + cfg.ListenAddress)
if err != nil {
panic(err)
}
if parsed.Host != cfg.ListenAddress {
panic("The backend definition seems to be invalid!")
}

go func() {
if err := newProxy(cfg.ListenAddress, cfg.Backend, tlsConfig).start(ctx); err != nil {
log.Fatalln("Unable to start proxy:", err)
Expand Down
Loading