summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--client/client.go4
-rw-r--r--cmd/daemon/main.go38
-rw-r--r--core/data.go9
-rw-r--r--core/payload.go6
-rw-r--r--server/channel.go48
-rw-r--r--server/message.go74
-rw-r--r--server/remote.go2
-rw-r--r--server/server.go30
-rw-r--r--server/storage.go1
-rw-r--r--server/user.go83
-rw-r--r--storage/storage.go37
-rwxr-xr-xtools/run-tls.sh8
-rwxr-xr-xtools/run-two-clients.sh8
-rwxr-xr-xtools/run.sh8
14 files changed, 195 insertions, 161 deletions
diff --git a/client/client.go b/client/client.go
index ac8ef67..8abbde2 100644
--- a/client/client.go
+++ b/client/client.go
@@ -60,7 +60,7 @@ func (c *Client) Connect() error {
return err
}
- auth := core.UserAuth{Name: c.cfg.User, Pass: c.cfg.Pass}
+ auth := core.UserAuth{Addr: c.cfg.User, Pass: c.cfg.Pass}
if err := core.Send(c.conn, auth); err != nil {
return err
}
@@ -71,7 +71,7 @@ func (c *Client) Connect() error {
func (c *Client) SendMessage(target, content string) error {
msg := core.Message{
- Source: c.cfg.User + "@" + c.cfg.Addr,
+ Source: c.cfg.User,
Target: target,
Content: content,
}
diff --git a/cmd/daemon/main.go b/cmd/daemon/main.go
index a118df3..98dd565 100644
--- a/cmd/daemon/main.go
+++ b/cmd/daemon/main.go
@@ -40,10 +40,12 @@ func main() {
enableTls := flag.Bool("tls", false, "Enable TLS")
certPath := flag.String("tls-cert", "", "TLS certificate PEM file path")
keyPath := flag.String("tls-key", "", "TLS key PEM file path")
+ test := flag.Bool("test", false, "Create test database entries")
+ dbPath := flag.String("db", "solec.db", "SQLite database path")
flag.Parse()
- db, err := storage.InitDb("test.db")
+ db, err := storage.InitDb(*dbPath)
if err != nil {
panic(err)
}
@@ -59,7 +61,10 @@ func main() {
}
serv = server.NewServer(cfg, db)
- serv.AddChannel("test")
+
+ if *test {
+ seedDatabase()
+ }
go func() {
if err := serv.Start(); err != nil {
@@ -185,10 +190,37 @@ func setPerm(args []string) {
}
}
+func seedDatabase() {
+ log.Println("creating test users")
+ log.Println("user1, user2, user3")
+ setUser([]string{"user1@localhost", "test"})
+ setUser([]string{"user2@localhost", "test"})
+ setUser([]string{"user3@localhost", "test"})
+
+ log.Println("setting #test@localhost channel permissions")
+ log.Println("user1, user2 -> #test")
+ setPerm([]string{"user1@localhost", "#test@localhost", "1", "1"})
+ setPerm([]string{"user2@localhost", "#test@localhost", "1", "1"})
+
+ log.Println("setting user channel permissions")
+ log.Println("user1 -> user1, user2, user3")
+ setPerm([]string{"user1@localhost", "user1@localhost", "1", "1"})
+ setPerm([]string{"user1@localhost", "user2@localhost", "1", "1"})
+ setPerm([]string{"user1@localhost", "user3@localhost", "1", "1"})
+
+ log.Println("user2 -> user1, user2")
+ setPerm([]string{"user2@localhost", "user1@localhost", "1", "1"})
+ setPerm([]string{"user2@localhost", "user2@localhost", "1", "1"})
+
+ log.Println("user3 -> user1, user3")
+ setPerm([]string{"user3@localhost", "user1@localhost", "1", "1"})
+ setPerm([]string{"user3@localhost", "user3@localhost", "1", "1"})
+}
+
func exit(args []string) {
os.Exit(0)
}
func printErr(err error) {
- fmt.Println("error:", err)
+ log.Println("error:", err)
}
diff --git a/core/data.go b/core/data.go
index e3fd560..fdad424 100644
--- a/core/data.go
+++ b/core/data.go
@@ -236,3 +236,12 @@ func ReadAddr(addrStr string) (Addr, error) {
return addr, nil
}
+
+func (a Addr) String() string {
+ var prefix string
+ if a.Type == AddrGroup {
+ prefix = "#"
+ }
+
+ return fmt.Sprintf("%s%s@%s", prefix, a.Channel, a.Host)
+}
diff --git a/core/payload.go b/core/payload.go
index b54a8cd..3bc6038 100644
--- a/core/payload.go
+++ b/core/payload.go
@@ -67,19 +67,19 @@ func DecodeHandshake(buf io.Reader) (Handshake, error) {
}
type UserAuth struct {
- Name string
+ Addr string
Pass string
}
func (a UserAuth) Wrap() (PayloadType, []any) {
return PayloadUserAuth, []any{
- a.Name, a.Pass,
+ a.Addr, a.Pass,
}
}
func DecodeUserAuth(buf io.Reader) (UserAuth, error) {
var a UserAuth
- err := decodeString(buf, &a.Name)
+ err := decodeString(buf, &a.Addr)
if err != nil {
return a, err
}
diff --git a/server/channel.go b/server/channel.go
deleted file mode 100644
index 11969f2..0000000
--- a/server/channel.go
+++ /dev/null
@@ -1,48 +0,0 @@
-package server
-
-import (
- "log"
- "net"
- "sync"
-
- "go.rctt.net/solec/core"
-)
-
-type Channel struct {
- Name string
- Users map[string]*User
- UsersMu sync.RWMutex
-}
-
-func NewChannel(name string) *Channel {
- return &Channel{
- Name: name,
- Users: make(map[string]*User),
- }
-}
-
-func (c *Channel) Add(u *User) {
- c.UsersMu.Lock()
- c.Users[u.Name] = u
- u.Channels[c.Name] = c
- c.UsersMu.Unlock()
-
- log.Println("user joined a channel")
-}
-
-func (c *Channel) Remove(u *User) {
- c.UsersMu.Lock()
- delete(c.Users, u.Name)
- delete(u.Channels, c.Name)
- c.UsersMu.Unlock()
-
- log.Println("user left a channel")
-}
-
-func (c *Channel) Send(senderConn net.Conn, msg core.Message) {
- for _, u := range c.Users {
- if err := u.Send(senderConn, msg); err != nil {
- log.Print("cannot send a message to user on channel", err)
- }
- }
-}
diff --git a/server/message.go b/server/message.go
index 19bfadf..801a4cf 100644
--- a/server/message.go
+++ b/server/message.go
@@ -30,7 +30,7 @@ func (s *Server) SendBroadcast(msg string) {
}
}
-func (s *Server) handleMessage(sender net.Conn, connType core.ConnType, msg core.Message) error {
+func (s *Server) handleMessage(sender net.Conn, connType core.ConnType, senderUser *User, msg core.Message) error {
if connType == core.ConnTypeUser {
msg.Timestamp = time.Now()
}
@@ -47,34 +47,60 @@ func (s *Server) handleMessage(sender net.Conn, connType core.ConnType, msg core
}
if addr.Host == s.cfg.Name {
- return s.handleLocalMessage(sender, addr, msg)
+ return s.handleLocalMessage(sender, senderUser, addr, msg)
}
- return s.handleOutboundMessage(sender, addr, msg)
+ return s.handleOutboundMessage(addr, msg)
}
-func (s *Server) handleLocalMessage(sender net.Conn, addr core.Addr, msg core.Message) error {
- if addr.Type == core.AddrUser {
- s.usersMu.RLock()
- user, ok := s.users[addr.Channel]
- if !ok {
- return core.Send(sender, core.Error{core.ErrorNotFound})
- }
- s.usersMu.RUnlock()
- return user.Send(sender, msg)
+func (s *Server) handleLocalMessage(sender net.Conn, senderUser *User, addr core.Addr, msg core.Message) error {
+ perm, err := s.Storage.GetPermission(senderUser.Addr, addr.String())
+ if err != nil {
+ log.Println("cannot get channel permissions:", err)
+ return core.Send(sender, core.Error{core.ErrorNotFound})
}
- s.channelsMu.RLock()
- channel, ok := s.channels[addr.Channel]
- if !ok {
+ if !perm.Write {
+ log.Println("user not authorized")
return core.Send(sender, core.Error{core.ErrorNotFound})
}
- s.channelsMu.RUnlock()
- channel.Send(sender, msg)
+
+ if addr.Type == core.AddrUser {
+ s.handleUserMessage(addr, sender, msg)
+ return nil
+ }
+
+ users, err := s.Storage.GetChannelUsers(addr.String())
+ if err != nil {
+ log.Println("cannot get channel users:", err)
+ return core.Send(sender, core.Error{core.ErrorUnknown})
+ }
+
+ for _, u := range users {
+ addr, err := core.ReadAddr(u)
+ if err != nil {
+ log.Println("cannot read user address:", err)
+ continue
+ }
+
+ if addr.Host != s.cfg.Name {
+ err := s.handleOutboundMessage(addr, msg)
+ if err != nil {
+ log.Println("cannot send group message to remote user:", err)
+ }
+ continue
+ }
+
+ err = s.handleUserMessage(addr, sender, msg)
+ if err != nil {
+ log.Println("cannot send group message to local user:", err)
+ }
+ }
+
return nil
}
-func (s *Server) handleOutboundMessage(sender net.Conn, addr core.Addr, msg core.Message) error {
+func (s *Server) handleOutboundMessage(addr core.Addr, msg core.Message) error {
remote, err := s.getRemote(addr.Host)
if err != nil {
return fmt.Errorf("cannot access remote server: %w", err)
@@ -82,3 +108,15 @@ func (s *Server) handleOutboundMessage(sender net.Conn, addr core.Addr, msg core
return core.Send(remote.Conn, msg)
}
+
+func (s *Server) handleUserMessage(addr core.Addr, sender net.Conn, msg core.Message) error {
+ s.usersMu.RLock()
+ user, ok := s.users[addr.String()]
+ if !ok {
+ log.Println("user not found")
+ return core.Send(sender, core.Error{core.ErrorNotFound})
+ }
+ s.usersMu.RUnlock()
+
+ return user.Send(sender, msg)
+}
diff --git a/server/remote.go b/server/remote.go
index e1829b1..70e4734 100644
--- a/server/remote.go
+++ b/server/remote.go
@@ -126,7 +126,7 @@ func (s *Server) readRemoteInput(conn net.Conn) error {
func (s *Server) handleRemotePayload(sender net.Conn, payload any) error {
switch v := payload.(type) {
case core.Message:
- return s.handleMessage(sender, core.ConnTypeServer, v)
+ return s.handleMessage(sender, core.ConnTypeServer, nil, v)
default:
return core.ErrUnexpectedPayloadType
}
diff --git a/server/server.go b/server/server.go
index b5840df..af43e3a 100644
--- a/server/server.go
+++ b/server/server.go
@@ -11,14 +11,12 @@ import (
)
type Server struct {
- cfg Config
- users map[string]User // TODO: Use full address instead of just name
- servers map[string]RemoteServer
- channels map[string]*Channel
- usersMu sync.RWMutex
- serversMu sync.RWMutex
- channelsMu sync.RWMutex
- Storage Storage
+ cfg Config
+ users map[string]User // TODO: Use full address instead of just name
+ servers map[string]RemoteServer
+ usersMu sync.RWMutex
+ serversMu sync.RWMutex
+ Storage Storage
}
type Config struct {
@@ -31,11 +29,10 @@ type Config struct {
func NewServer(cfg Config, storage Storage) *Server {
return &Server{
- cfg: cfg,
- Storage: storage,
- users: make(map[string]User),
- servers: make(map[string]RemoteServer),
- channels: make(map[string]*Channel),
+ cfg: cfg,
+ Storage: storage,
+ users: make(map[string]User),
+ servers: make(map[string]RemoteServer),
}
}
@@ -47,13 +44,6 @@ func (s *Server) Start() error {
return s.listenPlain()
}
-func (s *Server) AddChannel(name string) {
- s.channelsMu.Lock()
- defer s.channelsMu.Unlock()
- s.channels[name] = NewChannel(name)
- log.Println("created channel", name)
-}
-
func (s *Server) listenPlain() error {
ln, err := net.Listen("tcp", s.cfg.ListenAddr)
if err != nil {
diff --git a/server/storage.go b/server/storage.go
index 037c40f..1cb725f 100644
--- a/server/storage.go
+++ b/server/storage.go
@@ -16,4 +16,5 @@ type Storage interface {
SetPermission(data core.PermissionData) error
GetPermission(user, channel string) (core.PermissionData, error)
+ GetChannelUsers(channel string) ([]string, error)
}
diff --git a/server/user.go b/server/user.go
index 6ca857a..c615c40 100644
--- a/server/user.go
+++ b/server/user.go
@@ -9,16 +9,14 @@ import (
)
type User struct {
- Name string
- Conns map[net.Conn]struct{}
- Channels map[string]*Channel
+ Addr string
+ Conns map[net.Conn]struct{}
}
-func NewUser(conn net.Conn, name string) User {
+func NewUser(conn net.Conn, addr string) User {
u := User{
- Name: name,
- Conns: make(map[net.Conn]struct{}),
- Channels: make(map[string]*Channel),
+ Addr: addr,
+ Conns: make(map[net.Conn]struct{}),
}
u.Conns[conn] = struct{}{}
@@ -39,31 +37,31 @@ func (u *User) Send(senderConn net.Conn, payload core.Wrapper) error {
}
func (s *Server) handleUserConn(conn net.Conn) {
- name, err := s.performUserAuth(conn)
+ addr, err := s.performUserAuth(conn)
if err != nil {
log.Println("user auth error:", err)
return
}
s.usersMu.Lock()
- user, ok := s.users[name]
+ user, ok := s.users[addr]
if ok {
- log.Println("next connection from user:", user.Name)
+ log.Println("next connection from user:", user.Addr)
user.Conns[conn] = struct{}{}
} else {
- log.Println("initial connection from user:", name)
- user = NewUser(conn, name)
- s.users[name] = user
+ log.Println("initial connection from user:", addr)
+ user = NewUser(conn, addr)
+ s.users[addr] = user
}
s.usersMu.Unlock()
defer func() {
s.usersMu.Lock()
- log.Println("client disconnected: ", user.Name)
- delete(s.users[user.Name].Conns, conn)
- if len(s.users[user.Name].Conns) == 0 {
- log.Println("all connections closed for user:", user.Name)
- delete(s.users, user.Name)
+ log.Println("client disconnected: ", user.Addr)
+ delete(s.users[user.Addr].Conns, conn)
+ if len(s.users[user.Addr].Conns) == 0 {
+ log.Println("all connections closed for user:", user.Addr)
+ delete(s.users, user.Addr)
}
s.usersMu.Unlock()
}()
@@ -84,7 +82,7 @@ func (s *Server) performUserAuth(conn net.Conn) (string, error) {
return "", core.ErrUnexpectedPayloadType
}
- hash, err := s.Storage.GetUserPass(clientAuth.Name)
+ hash, err := s.Storage.GetUserPass(clientAuth.Addr)
if err != nil {
s.authFail(conn)
return "", core.ErrAuthInvalidUser
@@ -99,7 +97,7 @@ func (s *Server) performUserAuth(conn net.Conn) (string, error) {
return "", err
}
- return clientAuth.Name, nil
+ return clientAuth.Addr, nil
}
func (s *Server) authFail(conn net.Conn) {
@@ -123,7 +121,7 @@ func (s *Server) readUserInput(user *User, conn net.Conn) error {
func (s *Server) handleUserPayload(user *User, sender net.Conn, payload any) error {
switch v := payload.(type) {
case core.Message:
- return s.handleMessage(sender, core.ConnTypeUser, v)
+ return s.handleMessage(sender, core.ConnTypeUser, user, v)
case core.Usermode:
return s.handleUsermode(user, sender, v)
case core.History:
@@ -134,36 +132,25 @@ func (s *Server) handleUserPayload(user *User, sender net.Conn, payload any) err
}
func (s *Server) handleUsermode(user *User, conn net.Conn, mode core.Usermode) error {
- userAddr, err := core.ReadAddr(mode.UserAddr)
- if err != nil {
- return err
- }
-
- chanAddr, err := core.ReadAddr(mode.ChannelName)
- if err != nil {
- return err
- }
- if user.Name != userAddr.Channel {
- log.Println("unauthorized")
- return user.Send(conn, core.Error{core.ErrorUnauthorized})
- }
+ /*
+ userAddr, err := core.ReadAddr(mode.UserAddr)
+ if err != nil {
+ return err
+ }
- s.channelsMu.RLock()
- channel, ok := s.channels[chanAddr.Channel]
- if !ok {
- log.Println("not found", userAddr.Channel)
- return user.Send(conn, core.Error{core.ErrorNotFound})
- }
- s.channelsMu.RUnlock()
+ chanAddr, err := core.ReadAddr(mode.ChannelAddr)
+ if err != nil {
+ return err
+ }
- switch mode.Mode {
- case core.UsermodeNone:
- channel.Remove(user)
- case core.UsermodeInChannel:
- channel.Add(user)
- }
+ if user.Addr != userAddr.Channel {
+ log.Println("unauthorized")
+ return user.Send(conn, core.Error{core.ErrorUnauthorized})
+ }
+ // TODO: change user permissions here
+ */
return nil
}
@@ -177,7 +164,7 @@ func (s *Server) handleHistory(user *User, conn net.Conn, hist core.History) err
return user.Send(conn, core.Error{core.ErrorNotFound})
}
- perm, err := s.Storage.GetPermission(user.Name+"@"+s.cfg.Name, hist.Channel)
+ perm, err := s.Storage.GetPermission(user.Addr, hist.Channel)
if err != nil {
fmt.Println("cannot get message history:", err)
return user.Send(conn, core.Error{core.ErrorNotFound})
diff --git a/storage/storage.go b/storage/storage.go
index 510b587..1292945 100644
--- a/storage/storage.go
+++ b/storage/storage.go
@@ -17,7 +17,7 @@ const initSql = `
(id INTEGER NOT NULL PRIMARY KEY, source STRING, target STRING, timestamp INT, content STRING);
CREATE TABLE IF NOT EXISTS users
- (name STRING NOT NULL PRIMARY KEY, pass STRING);
+ (addr STRING NOT NULL PRIMARY KEY, pass STRING);
CREATE TABLE IF NOT EXISTS permissions
(user STRING, channel STRING, read INT, write INT, PRIMARY KEY (user, channel));
@@ -76,22 +76,22 @@ func (db *Database) GetHistory(channel string, since time.Time, num int, offset
func (db *Database) SetUser(user core.UserData) error {
_, err := db.Exec(
- "INSERT OR REPLACE INTO users (name, pass) VALUES (?, ?);",
+ "INSERT OR REPLACE INTO users (addr, pass) VALUES (?, ?);",
user.Name, user.Pass,
)
return err
}
-func (db *Database) DelUser(name string) error {
- _, err := db.Exec("DELETE FROM users WHERE name = ?", name)
+func (db *Database) DelUser(addr string) error {
+ _, err := db.Exec("DELETE FROM users WHERE addr = ?", addr)
return err
}
-func (db *Database) GetUserPass(name string) (string, error) {
+func (db *Database) GetUserPass(addr string) (string, error) {
var pass string
- err := db.QueryRow("SELECT pass FROM users WHERE name = ?", name).Scan(&pass)
+ err := db.QueryRow("SELECT pass FROM users WHERE addr = ?", addr).Scan(&pass)
if err != nil {
return "", err
}
@@ -126,6 +126,31 @@ func (db *Database) GetPermission(user, channel string) (core.PermissionData, er
}, nil
}
+func (db *Database) GetChannelUsers(channel string) (users []string, err error) {
+ rows, err := db.Query("SELECT user FROM permissions WHERE channel = ? AND write = 1;", channel)
+ defer func() {
+ if rows == nil {
+ return
+ }
+ if err := rows.Close(); err != nil {
+ log.Println("cannot close database row:", err)
+ }
+ }()
+ if err != nil {
+ return users, err
+ }
+
+ for rows.Next() {
+ var user string
+ if err := rows.Scan(&user); err != nil {
+ return users, err
+ }
+ users = append(users, user)
+ }
+
+ return users, nil
+}
+
func itob(v int) bool {
if v == 1 {
return true
diff --git a/tools/run-tls.sh b/tools/run-tls.sh
index 894a9f7..c1b84ea 100755
--- a/tools/run-tls.sh
+++ b/tools/run-tls.sh
@@ -1,8 +1,8 @@
#!/bin/sh
tmux \
- new-session "go run cmd/daemon/main.go -tls -tls-cert cert.pem -tls-key key.pem; read" \; \
- split-window "sleep 0.5; go run cmd/client/main.go -tls -tls-insecure -u user1; read" \; \
- split-window "sleep 0.5; go run cmd/client/main.go -tls -tls-insecure -u user2; read" \; \
- split-window "sleep 0.5; go run cmd/client/main.go -tls -tls-insecure -u user3; read" \; \
+ new-session "go run cmd/daemon/main.go -tls -tls-cert cert.pem -tls-key key.pem -test -db test.db; read" \; \
+ split-window "sleep 1; go run cmd/client/main.go -tls -tls-insecure -u user1; read" \; \
+ split-window "sleep 1; go run cmd/client/main.go -tls -tls-insecure -u user2; read" \; \
+ split-window "sleep 1; go run cmd/client/main.go -tls -tls-insecure -u user3; read" \; \
select-layout tiled; \ No newline at end of file
diff --git a/tools/run-two-clients.sh b/tools/run-two-clients.sh
index 884a88a..d5d9294 100755
--- a/tools/run-two-clients.sh
+++ b/tools/run-two-clients.sh
@@ -1,8 +1,8 @@
#!/bin/sh
tmux \
- new-session "go run cmd/daemon/main.go; read" \; \
- split-window "sleep 0.5; go run cmd/client/main.go -u user1; read" \; \
- split-window "sleep 0.5; go run cmd/client/main.go -u user1; read" \; \
- split-window "sleep 0.5; go run cmd/client/main.go -u user2; read" \; \
+ new-session "go run cmd/daemon/main.go -test -db test.db; read" \; \
+ split-window "sleep 1; go run cmd/client/main.go -u user1@localhost -p test; read" \; \
+ split-window "sleep 1; go run cmd/client/main.go -u user1@localhost -p test; read" \; \
+ split-window "sleep 1; go run cmd/client/main.go -u user2@localhost -p test; read" \; \
select-layout tiled; \ No newline at end of file
diff --git a/tools/run.sh b/tools/run.sh
index dbe0bf0..2245c90 100755
--- a/tools/run.sh
+++ b/tools/run.sh
@@ -1,8 +1,8 @@
#!/bin/sh
tmux \
- new-session "go run cmd/daemon/main.go; read" \; \
- split-window "sleep 0.5; go run cmd/client/main.go -u user1; read" \; \
- split-window "sleep 0.5; go run cmd/client/main.go -u user2; read" \; \
- split-window "sleep 0.5; go run cmd/client/main.go -u user3; read" \; \
+ new-session "go run cmd/daemon/main.go -test -db test.db; read" \; \
+ split-window "sleep 1; go run cmd/client/main.go -u user1@localhost -p test; read" \; \
+ split-window "sleep 1; go run cmd/client/main.go -u user2@localhost -p test; read" \; \
+ split-window "sleep 1; go run cmd/client/main.go -u user3@localhost -p test; read" \; \
select-layout tiled; \ No newline at end of file