Add tool to check XMPP certificate on starttls
Minor refactoring
This commit is contained in:
parent
67d9170354
commit
d16c4cbba4
2
auth.go
2
auth.go
|
@ -8,7 +8,7 @@ import (
|
||||||
"io"
|
"io"
|
||||||
)
|
)
|
||||||
|
|
||||||
func authSASL(socket io.ReadWriter, decoder *xml.Decoder, f streamFeatures, user string, password string) (err error) {
|
func authSASL(socket io.ReadWriter, decoder *xml.Decoder, f StreamFeatures, user string, password string) (err error) {
|
||||||
// TODO: Implement other type of SASL Authentication
|
// TODO: Implement other type of SASL Authentication
|
||||||
havePlain := false
|
havePlain := false
|
||||||
for _, m := range f.Mechanisms.Mechanism {
|
for _, m := range f.Mechanisms.Mechanism {
|
||||||
|
|
146
check_cert.go
Normal file
146
check_cert.go
Normal file
|
@ -0,0 +1,146 @@
|
||||||
|
package xmpp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"encoding/xml"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TODO: Should I move this as an extension of the client?
|
||||||
|
// I should probably make the code more modular, but keep concern separated to keep it simple.
|
||||||
|
type ServerCheck struct {
|
||||||
|
address string
|
||||||
|
domain string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewChecker(address, domain string) (*ServerCheck, error) {
|
||||||
|
client := ServerCheck{}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
var host string
|
||||||
|
if client.address, host, err = extractParams(address); err != nil {
|
||||||
|
return &client, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if domain != "" {
|
||||||
|
client.domain = domain
|
||||||
|
} else {
|
||||||
|
client.domain = host
|
||||||
|
}
|
||||||
|
|
||||||
|
return &client, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check triggers actual TCP connection, based on previously defined parameters.
|
||||||
|
func (c *ServerCheck) Check() error {
|
||||||
|
var tcpconn net.Conn
|
||||||
|
var err error
|
||||||
|
|
||||||
|
timeout := 15 * time.Second
|
||||||
|
tcpconn, err = net.DialTimeout("tcp", c.address, timeout)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
decoder := xml.NewDecoder(tcpconn)
|
||||||
|
|
||||||
|
// Send stream open tag
|
||||||
|
if _, err = fmt.Fprintf(tcpconn, xmppStreamOpen, c.domain, NSClient, NSStream); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set xml decoder and extract streamID from reply (not used for now)
|
||||||
|
_, err = initDecoder(decoder)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// extract stream features
|
||||||
|
var f StreamFeatures
|
||||||
|
packet, err := next(decoder)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("stream open decode features: %s", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
switch p := packet.(type) {
|
||||||
|
case StreamFeatures:
|
||||||
|
f = p
|
||||||
|
case StreamError:
|
||||||
|
return errors.New("open stream error: " + p.Error.Local)
|
||||||
|
default:
|
||||||
|
return errors.New("expected packet received while expecting features, got " + p.Name())
|
||||||
|
}
|
||||||
|
|
||||||
|
startTLSFeature := f.StartTLS.XMLName.Space + " " + f.StartTLS.XMLName.Local
|
||||||
|
if startTLSFeature == nsTLS+" starttls" {
|
||||||
|
fmt.Fprintf(tcpconn, "<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>")
|
||||||
|
|
||||||
|
var k tlsProceed
|
||||||
|
if err = decoder.DecodeElement(&k, nil); err != nil {
|
||||||
|
return fmt.Errorf("expecting starttls proceed: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
DefaultTlsConfig.ServerName = c.domain
|
||||||
|
tlsConn := tls.Client(tcpconn, &DefaultTlsConfig)
|
||||||
|
// We convert existing connection to TLS
|
||||||
|
if err = tlsConn.Handshake(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// We check that cert matches hostname
|
||||||
|
if err = tlsConn.VerifyHostname(c.domain); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = checkExpiration(tlsConn); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return errors.New("TLS not supported on server")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check expiration date for the whole certificate chain and returns an error
|
||||||
|
// if the expiration date is in less than 48 hours.
|
||||||
|
func checkExpiration(tlsConn *tls.Conn) error {
|
||||||
|
checkedCerts := make(map[string]struct{})
|
||||||
|
for _, chain := range tlsConn.ConnectionState().VerifiedChains {
|
||||||
|
for _, cert := range chain {
|
||||||
|
if _, checked := checkedCerts[string(cert.Signature)]; checked {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
checkedCerts[string(cert.Signature)] = struct{}{}
|
||||||
|
|
||||||
|
// Check the expiration.
|
||||||
|
timeNow := time.Now()
|
||||||
|
expiresInHours := int64(cert.NotAfter.Sub(timeNow).Hours())
|
||||||
|
// fmt.Printf("Cert '%s' expires in %d days\n", cert.Subject.CommonName, expiresInHours/24)
|
||||||
|
if expiresInHours <= 48 {
|
||||||
|
return fmt.Errorf("certificate '%s' will expire on %s", cert.Subject.CommonName, cert.NotAfter)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractParams(addr string) (string, string, error) {
|
||||||
|
var err error
|
||||||
|
hostport := strings.Split(addr, ":")
|
||||||
|
if len(hostport) > 2 {
|
||||||
|
err = errors.New("too many colons in xmpp server address")
|
||||||
|
return addr, hostport[0], err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Address is composed of two parts, we are good
|
||||||
|
if len(hostport) == 2 && hostport[1] != "" {
|
||||||
|
return addr, hostport[0], err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Port was not passed, we append XMPP default port:
|
||||||
|
return strings.Join([]string{hostport[0], "5222"}, ":"), hostport[0], err
|
||||||
|
}
|
3
cmd/xmpp-check/TODO.md
Normal file
3
cmd/xmpp-check/TODO.md
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
# TODO
|
||||||
|
|
||||||
|
- Use a config file to define the checks to perform as client on an XMPP server.
|
43
cmd/xmpp-check/xmpp-check.go
Normal file
43
cmd/xmpp-check/xmpp-check.go
Normal file
|
@ -0,0 +1,43 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"gosrc.io/xmpp"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
args := os.Args[1:]
|
||||||
|
|
||||||
|
if len(args) == 0 {
|
||||||
|
log.Fatal("usage: xmpp-check host[:port] [domain]")
|
||||||
|
}
|
||||||
|
|
||||||
|
var address string
|
||||||
|
var domain string
|
||||||
|
if len(args) >= 1 {
|
||||||
|
address = args[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(args) >= 2 {
|
||||||
|
domain = args[1]
|
||||||
|
}
|
||||||
|
|
||||||
|
runCheck(address, domain)
|
||||||
|
}
|
||||||
|
|
||||||
|
func runCheck(address, domain string) {
|
||||||
|
client, err := xmpp.NewChecker(address, domain)
|
||||||
|
// client, err := xmpp.NewChecker("mickael.m.in-app.io:5222", "mickael.m.in-app.io")
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal("Error: ", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = client.Check(); err != nil {
|
||||||
|
log.Fatal("Failed connection check: ", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Println("All checks passed")
|
||||||
|
}
|
|
@ -90,6 +90,8 @@ func decodeStream(p *xml.Decoder, se xml.StartElement) (Packet, error) {
|
||||||
switch se.Name.Local {
|
switch se.Name.Local {
|
||||||
case "error":
|
case "error":
|
||||||
return streamError.decode(p, se)
|
return streamError.decode(p, se)
|
||||||
|
case "features":
|
||||||
|
return streamFeatures.decode(p, se)
|
||||||
default:
|
default:
|
||||||
return nil, errors.New("unexpected XMPP packet " +
|
return nil, errors.New("unexpected XMPP packet " +
|
||||||
se.Name.Space + " <" + se.Name.Local + "/>")
|
se.Name.Space + " <" + se.Name.Local + "/>")
|
||||||
|
|
|
@ -15,7 +15,7 @@ type Session struct {
|
||||||
// Session info
|
// Session info
|
||||||
BindJid string // Jabber ID as provided by XMPP server
|
BindJid string // Jabber ID as provided by XMPP server
|
||||||
StreamId string
|
StreamId string
|
||||||
Features streamFeatures
|
Features StreamFeatures
|
||||||
TlsEnabled bool
|
TlsEnabled bool
|
||||||
lastPacketId int
|
lastPacketId int
|
||||||
|
|
||||||
|
@ -85,7 +85,7 @@ func (s *Session) setProxy(conn net.Conn, newConn net.Conn, o Config) {
|
||||||
s.decoder.CharsetReader = o.CharsetReader
|
s.decoder.CharsetReader = o.CharsetReader
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Session) open(domain string) (f streamFeatures) {
|
func (s *Session) open(domain string) (f StreamFeatures) {
|
||||||
// Send stream open tag
|
// Send stream open tag
|
||||||
if _, s.err = fmt.Fprintf(s.socketProxy, xmppStreamOpen, domain, NSClient, NSStream); s.err != nil {
|
if _, s.err = fmt.Fprintf(s.socketProxy, xmppStreamOpen, domain, NSClient, NSStream); s.err != nil {
|
||||||
return
|
return
|
||||||
|
@ -121,7 +121,7 @@ func (s *Session) startTlsIfSupported(conn net.Conn, domain string) net.Conn {
|
||||||
|
|
||||||
// TODO: add option to accept all TLS certificates: insecureSkipTlsVerify (DefaultTlsConfig.InsecureSkipVerify)
|
// TODO: add option to accept all TLS certificates: insecureSkipTlsVerify (DefaultTlsConfig.InsecureSkipVerify)
|
||||||
DefaultTlsConfig.ServerName = domain
|
DefaultTlsConfig.ServerName = domain
|
||||||
var tlsConn *tls.Conn = tls.Client(conn, &DefaultTlsConfig)
|
tlsConn := tls.Client(conn, &DefaultTlsConfig)
|
||||||
// We convert existing connection to TLS
|
// We convert existing connection to TLS
|
||||||
if s.err = tlsConn.Handshake(); s.err != nil {
|
if s.err = tlsConn.Handshake(); s.err != nil {
|
||||||
return tlsConn
|
return tlsConn
|
||||||
|
|
|
@ -20,21 +20,21 @@ func newSocketProxy(conn io.ReadWriter, logFile *os.File) io.ReadWriter {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pl *socketProxy) Read(p []byte) (n int, err error) {
|
func (sp *socketProxy) Read(p []byte) (n int, err error) {
|
||||||
n, err = pl.socket.Read(p)
|
n, err = sp.socket.Read(p)
|
||||||
if n > 0 {
|
if n > 0 {
|
||||||
pl.logFile.Write([]byte("RECV:\n")) // Prefix
|
sp.logFile.Write([]byte("RECV:\n")) // Prefix
|
||||||
if n, err := pl.logFile.Write(p[:n]); err != nil {
|
if n, err := sp.logFile.Write(p[:n]); err != nil {
|
||||||
return n, err
|
return n, err
|
||||||
}
|
}
|
||||||
pl.logFile.Write([]byte("\n\n")) // Separator
|
sp.logFile.Write([]byte("\n\n")) // Separator
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pl *socketProxy) Write(p []byte) (n int, err error) {
|
func (sp *socketProxy) Write(p []byte) (n int, err error) {
|
||||||
pl.logFile.Write([]byte("SEND:\n")) // Prefix
|
sp.logFile.Write([]byte("SEND:\n")) // Prefix
|
||||||
for _, w := range []io.Writer{pl.socket, pl.logFile} {
|
for _, w := range []io.Writer{sp.socket, sp.logFile} {
|
||||||
n, err = w.Write(p)
|
n, err = w.Write(p)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
|
@ -44,6 +44,6 @@ func (pl *socketProxy) Write(p []byte) (n int, err error) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
pl.logFile.Write([]byte("\n\n")) // Separator
|
sp.logFile.Write([]byte("\n\n")) // Separator
|
||||||
return len(p), nil
|
return len(p), nil
|
||||||
}
|
}
|
||||||
|
|
16
stream.go
16
stream.go
|
@ -7,7 +7,7 @@ import (
|
||||||
// ============================================================================
|
// ============================================================================
|
||||||
// StreamFeatures Packet
|
// StreamFeatures Packet
|
||||||
|
|
||||||
type streamFeatures struct {
|
type StreamFeatures struct {
|
||||||
XMLName xml.Name `xml:"http://etherx.jabber.org/streams features"`
|
XMLName xml.Name `xml:"http://etherx.jabber.org/streams features"`
|
||||||
StartTLS tlsStartTLS
|
StartTLS tlsStartTLS
|
||||||
Caps Caps
|
Caps Caps
|
||||||
|
@ -17,6 +17,20 @@ type streamFeatures struct {
|
||||||
Any []xml.Name `xml:",any"`
|
Any []xml.Name `xml:",any"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (StreamFeatures) Name() string {
|
||||||
|
return "stream:features"
|
||||||
|
}
|
||||||
|
|
||||||
|
type streamFeatureDecoder struct{}
|
||||||
|
|
||||||
|
var streamFeatures streamFeatureDecoder
|
||||||
|
|
||||||
|
func (streamFeatureDecoder) decode(p *xml.Decoder, se xml.StartElement) (StreamFeatures, error) {
|
||||||
|
var packet StreamFeatures
|
||||||
|
err := p.DecodeElement(&packet, &se)
|
||||||
|
return packet, err
|
||||||
|
}
|
||||||
|
|
||||||
// ============================================================================
|
// ============================================================================
|
||||||
// StreamError Packet
|
// StreamError Packet
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue