282 lines
6.3 KiB
Go
282 lines
6.3 KiB
Go
package xmpp
|
|
|
|
import (
|
|
"github.com/pkg/errors"
|
|
"regexp"
|
|
"strconv"
|
|
"sync"
|
|
"time"
|
|
|
|
"dev.narayana.im/narayana/telegabber/config"
|
|
"dev.narayana.im/narayana/telegabber/persistence"
|
|
"dev.narayana.im/narayana/telegabber/telegram"
|
|
"dev.narayana.im/narayana/telegabber/xmpp/gateway"
|
|
|
|
log "github.com/sirupsen/logrus"
|
|
"gosrc.io/xmpp"
|
|
"gosrc.io/xmpp/stanza"
|
|
)
|
|
|
|
var tgConf config.TelegramConfig
|
|
var sessions map[string]*telegram.Client
|
|
var db *persistence.SessionsYamlDB
|
|
var sessionLock sync.Mutex
|
|
|
|
// Binary size units, expressed in bytes, understood by parseSize.
const (
	B uint64 = 1 << (10 * iota)
	KB
	MB
	GB
	TB
	PB
	EB

	// maxUint64 is the largest uint64 value; parseSize divides it by the
	// unit to detect overflow before multiplying.
	maxUint64 uint64 = 1<<64 - 1
)

// sizeRegex matches the whole input of parseSize: a decimal number
// (group 1), an optional single space, then a unit made of an optional
// K/M/G/T/P/E letter followed by an optional "B" (group 2, may be empty).
var sizeRegex = regexp.MustCompile(`\A([0-9]+) ?([KMGTPE]?B?)\z`)
|
|
|
|
// NewComponent starts a new component and wraps it in
|
|
// a stream manager that you should start yourself
|
|
func NewComponent(conf config.XMPPConfig, tc config.TelegramConfig) (*xmpp.StreamManager, *xmpp.Component, error) {
|
|
var err error
|
|
|
|
gateway.Jid, err = stanza.NewJid(conf.Jid)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
tgConf = tc
|
|
|
|
if tc.Content.Quota != "" {
|
|
gateway.StorageQuota, err = parseSize(tc.Content.Quota)
|
|
if err != nil {
|
|
log.Warnf("Error parsing the storage quota: %v; the cleaner is disabled", err)
|
|
}
|
|
}
|
|
|
|
options := xmpp.ComponentOptions{
|
|
TransportConfiguration: xmpp.TransportConfiguration{
|
|
Address: conf.Host + ":" + conf.Port,
|
|
Domain: conf.Jid,
|
|
},
|
|
Domain: conf.Jid,
|
|
Secret: conf.Password,
|
|
Name: "telegabber",
|
|
}
|
|
|
|
router := xmpp.NewRouter()
|
|
router.HandleFunc("iq", HandleIq)
|
|
router.HandleFunc("presence", HandlePresence)
|
|
router.HandleFunc("message", HandleMessage)
|
|
|
|
component, err := xmpp.NewComponent(options, router, func(err error) {
|
|
log.Error(err)
|
|
})
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
// probe all known sessions
|
|
err = loadSessions(conf.Db, component)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
sm := xmpp.NewStreamManager(component, func(s xmpp.Sender) {
|
|
go heartbeat(component)
|
|
})
|
|
|
|
return sm, component, nil
|
|
}
|
|
|
|
func heartbeat(component *xmpp.Component) {
|
|
var err error
|
|
probeType := gateway.SPType("probe")
|
|
|
|
sessionLock.Lock()
|
|
for jid := range sessions {
|
|
err = gateway.SendPresence(component, jid, probeType)
|
|
if err != nil {
|
|
log.Error(err)
|
|
}
|
|
}
|
|
sessionLock.Unlock()
|
|
|
|
quotaLowThreshold := gateway.StorageQuota / 10 * 9
|
|
|
|
log.Info("Starting heartbeat queue")
|
|
|
|
// status updater thread
|
|
for {
|
|
gateway.StorageLock.Lock()
|
|
if quotaLowThreshold > 0 && tgConf.Content.Path != "" {
|
|
gateway.MeasureStorageSize(tgConf.Content.Path)
|
|
|
|
if gateway.CachedStorageSize > quotaLowThreshold {
|
|
gateway.CleanOldFiles(tgConf.Content.Path, quotaLowThreshold)
|
|
}
|
|
}
|
|
gateway.StorageLock.Unlock()
|
|
|
|
time.Sleep(60e9)
|
|
now := time.Now().Unix()
|
|
|
|
sessionLock.Lock()
|
|
for _, session := range sessions {
|
|
session.DelayedStatusesLock.Lock()
|
|
for chatID, delayedStatus := range session.DelayedStatuses {
|
|
if delayedStatus.TimestampExpired <= now {
|
|
go session.ProcessStatusUpdate(
|
|
chatID,
|
|
session.LastSeenStatus(delayedStatus.TimestampOnline),
|
|
"away",
|
|
)
|
|
delete(session.DelayedStatuses, chatID)
|
|
}
|
|
}
|
|
session.DelayedStatusesLock.Unlock()
|
|
}
|
|
sessionLock.Unlock()
|
|
|
|
for key, presence := range gateway.Queue {
|
|
err = gateway.ResumableSend(component, presence)
|
|
if err != nil {
|
|
gateway.LogBadPresence(presence)
|
|
} else {
|
|
gateway.QueueLock.Lock()
|
|
delete(gateway.Queue, key)
|
|
gateway.QueueLock.Unlock()
|
|
}
|
|
}
|
|
|
|
if gateway.DirtySessions {
|
|
gateway.DirtySessions = false
|
|
// no problem if a dirty flag gets set again here,
|
|
// it would be resolved on the next iteration
|
|
SaveSessions()
|
|
}
|
|
}
|
|
}
|
|
|
|
func loadSessions(dbPath string, component *xmpp.Component) error {
|
|
var err error
|
|
|
|
sessions = make(map[string]*telegram.Client)
|
|
|
|
db, err = persistence.LoadSessions(dbPath)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
db.Transaction(func() bool {
|
|
for jid, session := range db.Data.Sessions {
|
|
// copy the session struct, otherwise all of them would reference
|
|
// the same temporary range variable
|
|
currentSession := session
|
|
getTelegramInstance(jid, ¤tSession, component)
|
|
}
|
|
|
|
return false
|
|
}, persistence.SessionMarshaller)
|
|
|
|
return nil
|
|
}
|
|
|
|
func getTelegramInstance(jid string, savedSession *persistence.Session, component *xmpp.Component) (*telegram.Client, bool) {
|
|
var err error
|
|
session, ok := sessions[jid]
|
|
if !ok {
|
|
session, err = telegram.NewClient(tgConf, jid, component, savedSession)
|
|
if err != nil {
|
|
log.Error(errors.Wrap(err, "TDlib initialization failure"))
|
|
return session, false
|
|
}
|
|
if savedSession.KeepOnline {
|
|
if err = session.Connect(""); err != nil {
|
|
log.Error(err)
|
|
return session, false
|
|
}
|
|
}
|
|
sessionLock.Lock()
|
|
sessions[jid] = session
|
|
sessionLock.Unlock()
|
|
}
|
|
|
|
return session, true
|
|
}
|
|
|
|
// SaveSessions dumps current sessions to the file
|
|
func SaveSessions() {
|
|
sessionLock.Lock()
|
|
defer sessionLock.Unlock()
|
|
db.Transaction(func() bool {
|
|
for jid, session := range sessions {
|
|
db.Data.Sessions[jid] = *session.Session
|
|
}
|
|
|
|
return true
|
|
}, persistence.SessionMarshaller)
|
|
}
|
|
|
|
// Close gracefully terminates the component and saves active sessions
|
|
func Close(component *xmpp.Component) {
|
|
log.Error("Disconnecting...")
|
|
|
|
sessionLock.Lock()
|
|
// close all sessions
|
|
for _, session := range sessions {
|
|
session.Disconnect("", true)
|
|
}
|
|
sessionLock.Unlock()
|
|
|
|
// save sessions
|
|
SaveSessions()
|
|
|
|
// close stream
|
|
component.Disconnect()
|
|
}
|
|
|
|
// based on https://github.com/c2h5oh/datasize/blob/master/datasize.go
|
|
func parseSize(sSize string) (uint64, error) {
|
|
sizeParts := sizeRegex.FindStringSubmatch(sSize)
|
|
|
|
if len(sizeParts) > 2 {
|
|
numPart, err := strconv.ParseInt(sizeParts[1], 10, 64)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
var divisor uint64
|
|
val := uint64(numPart)
|
|
|
|
if len(sizeParts[2]) > 0 {
|
|
switch sizeParts[2][0] {
|
|
case 'B':
|
|
divisor = 1
|
|
case 'K':
|
|
divisor = KB
|
|
case 'M':
|
|
divisor = MB
|
|
case 'G':
|
|
divisor = GB
|
|
case 'T':
|
|
divisor = TB
|
|
case 'P':
|
|
divisor = PB
|
|
case 'E':
|
|
divisor = EB
|
|
}
|
|
}
|
|
|
|
if divisor == 0 {
|
|
return 0, &strconv.NumError{"Wrong suffix", sSize, strconv.ErrSyntax}
|
|
}
|
|
if val > maxUint64/divisor {
|
|
return 0, &strconv.NumError{"Overflow", sSize, strconv.ErrRange}
|
|
}
|
|
return val * divisor, nil
|
|
}
|
|
|
|
return 0, &strconv.NumError{"Not enough parts", sSize, strconv.ErrSyntax}
|
|
}
|