diff options
| -rw-r--r-- | client/client.go | 4 | ||||
| -rw-r--r-- | cmd/daemon/main.go | 38 | ||||
| -rw-r--r-- | core/data.go | 9 | ||||
| -rw-r--r-- | core/payload.go | 6 | ||||
| -rw-r--r-- | server/channel.go | 48 | ||||
| -rw-r--r-- | server/message.go | 74 | ||||
| -rw-r--r-- | server/remote.go | 2 | ||||
| -rw-r--r-- | server/server.go | 30 | ||||
| -rw-r--r-- | server/storage.go | 1 | ||||
| -rw-r--r-- | server/user.go | 83 | ||||
| -rw-r--r-- | storage/storage.go | 37 | ||||
| -rwxr-xr-x | tools/run-tls.sh | 8 | ||||
| -rwxr-xr-x | tools/run-two-clients.sh | 8 | ||||
| -rwxr-xr-x | tools/run.sh | 8 |
14 files changed, 195 insertions, 161 deletions
diff --git a/client/client.go b/client/client.go index ac8ef67..8abbde2 100644 --- a/client/client.go +++ b/client/client.go @@ -60,7 +60,7 @@ func (c *Client) Connect() error { return err } - auth := core.UserAuth{Name: c.cfg.User, Pass: c.cfg.Pass} + auth := core.UserAuth{Addr: c.cfg.User, Pass: c.cfg.Pass} if err := core.Send(c.conn, auth); err != nil { return err } @@ -71,7 +71,7 @@ func (c *Client) Connect() error { func (c *Client) SendMessage(target, content string) error { msg := core.Message{ - Source: c.cfg.User + "@" + c.cfg.Addr, + Source: c.cfg.User, Target: target, Content: content, } diff --git a/cmd/daemon/main.go b/cmd/daemon/main.go index a118df3..98dd565 100644 --- a/cmd/daemon/main.go +++ b/cmd/daemon/main.go @@ -40,10 +40,12 @@ func main() { enableTls := flag.Bool("tls", false, "Enable TLS") certPath := flag.String("tls-cert", "", "TLS certificate PEM file path") keyPath := flag.String("tls-key", "", "TLS key PEM file path") + test := flag.Bool("test", false, "Create test database entries") + dbPath := flag.String("db", "solec.db", "SQLite database path") flag.Parse() - db, err := storage.InitDb("test.db") + db, err := storage.InitDb(*dbPath) if err != nil { panic(err) } @@ -59,7 +61,10 @@ func main() { } serv = server.NewServer(cfg, db) - serv.AddChannel("test") + + if *test { + seedDatabase() + } go func() { if err := serv.Start(); err != nil { @@ -185,10 +190,37 @@ func setPerm(args []string) { } } +func seedDatabase() { + log.Println("creating test users") + log.Println("user1, user2, user3") + setUser([]string{"user1@localhost", "test"}) + setUser([]string{"user2@localhost", "test"}) + setUser([]string{"user3@localhost", "test"}) + + log.Println("setting #test@localhost channel permissions") + log.Println("user1, user2 -> #test") + setPerm([]string{"user1@localhost", "#test@localhost", "1", "1"}) + setPerm([]string{"user2@localhost", "#test@localhost", "1", "1"}) + + log.Println("setting user channel permissions") + log.Println("user1 -> user1, user2, user3") + setPerm([]string{"user1@localhost", "user1@localhost", "1", "1"}) + setPerm([]string{"user1@localhost", "user2@localhost", "1", "1"}) + setPerm([]string{"user1@localhost", "user3@localhost", "1", "1"}) + + log.Println("user2 -> user1, user2") + setPerm([]string{"user2@localhost", "user1@localhost", "1", "1"}) + setPerm([]string{"user2@localhost", "user2@localhost", "1", "1"}) + + log.Println("user3 -> user1, user3") + setPerm([]string{"user3@localhost", "user1@localhost", "1", "1"}) + setPerm([]string{"user3@localhost", "user3@localhost", "1", "1"}) +} + func exit(args []string) { os.Exit(0) } func printErr(err error) { - fmt.Println("error:", err) + log.Println("error:", err) } diff --git a/core/data.go b/core/data.go index e3fd560..fdad424 100644 --- a/core/data.go +++ b/core/data.go @@ -236,3 +236,12 @@ func ReadAddr(addrStr string) (Addr, error) { return addr, nil } + +func (a Addr) String() string { + var prefix string + if a.Type == AddrGroup { + prefix = "#" + } + + return fmt.Sprintf("%s%s@%s", prefix, a.Channel, a.Host) +} diff --git a/core/payload.go b/core/payload.go index b54a8cd..3bc6038 100644 --- a/core/payload.go +++ b/core/payload.go @@ -67,19 +67,19 @@ func DecodeHandshake(buf io.Reader) (Handshake, error) { } type UserAuth struct { - Name string + Addr string Pass string } func (a UserAuth) Wrap() (PayloadType, []any) { return PayloadUserAuth, []any{ - a.Name, a.Pass, + a.Addr, a.Pass, } } func DecodeUserAuth(buf io.Reader) (UserAuth, error) { var a UserAuth - err := decodeString(buf, &a.Name) + err := decodeString(buf, &a.Addr) if err != nil { return a, err } diff --git a/server/channel.go b/server/channel.go deleted file mode 100644 index 11969f2..0000000 --- a/server/channel.go +++ /dev/null @@ -1,48 +0,0 @@ -package server - -import ( - "log" - "net" - "sync" - - "go.rctt.net/solec/core" -) - -type Channel struct { - Name string - Users map[string]*User - UsersMu sync.RWMutex -} - -func NewChannel(name string) *Channel { - return &Channel{ - Name: name, - Users: make(map[string]*User), - } -} - -func (c *Channel) Add(u *User) { - c.UsersMu.Lock() - c.Users[u.Name] = u - u.Channels[c.Name] = c - c.UsersMu.Unlock() - - log.Println("user joined a channel") -} - -func (c *Channel) Remove(u *User) { - c.UsersMu.Lock() - delete(c.Users, u.Name) - delete(u.Channels, c.Name) - c.UsersMu.Unlock() - - log.Println("user left a channel") -} - -func (c *Channel) Send(senderConn net.Conn, msg core.Message) { - for _, u := range c.Users { - if err := u.Send(senderConn, msg); err != nil { - log.Print("cannot send a message to user on channel", err) - } - } -} diff --git a/server/message.go b/server/message.go index 19bfadf..801a4cf 100644 --- a/server/message.go +++ b/server/message.go @@ -30,7 +30,7 @@ func (s *Server) SendBroadcast(msg string) { } } -func (s *Server) handleMessage(sender net.Conn, connType core.ConnType, msg core.Message) error { +func (s *Server) handleMessage(sender net.Conn, connType core.ConnType, senderUser *User, msg core.Message) error { if connType == core.ConnTypeUser { msg.Timestamp = time.Now() } @@ -47,34 +47,60 @@ func (s *Server) handleMessage(sender net.Conn, connType core.ConnType, msg core } if addr.Host == s.cfg.Name { - return s.handleLocalMessage(sender, addr, msg) + return s.handleLocalMessage(sender, senderUser, addr, msg) } - return s.handleOutboundMessage(sender, addr, msg) + return s.handleOutboundMessage(addr, msg) } -func (s *Server) handleLocalMessage(sender net.Conn, addr core.Addr, msg core.Message) error { - if addr.Type == core.AddrUser { - s.usersMu.RLock() - user, ok := s.users[addr.Channel] - if !ok { - return core.Send(sender, core.Error{core.ErrorNotFound}) - } - s.usersMu.RUnlock() - return user.Send(sender, msg) +func (s *Server) handleLocalMessage(sender net.Conn, senderUser *User, addr core.Addr, msg core.Message) error { + perm, err := s.Storage.GetPermission(senderUser.Addr, addr.String()) + if err != nil { + log.Println("cannot get channel permissions:", err) + return core.Send(sender, core.Error{core.ErrorNotFound}) } - s.channelsMu.RLock() - channel, ok := s.channels[addr.Channel] - if !ok { + if !perm.Write { + log.Println("user not authorized") return core.Send(sender, core.Error{core.ErrorNotFound}) } - s.channelsMu.RUnlock() - channel.Send(sender, msg) + + if addr.Type == core.AddrUser { + s.handleUserMessage(addr, sender, msg) + return nil + } + + users, err := s.Storage.GetChannelUsers(addr.String()) + if err != nil { + log.Println("cannot get channel users:", err) + return core.Send(sender, core.Error{core.ErrorUnknown}) + } + + for _, u := range users { + addr, err := core.ReadAddr(u) + if err != nil { + log.Println("cannot read user address:", err) + continue + } + + if addr.Host != s.cfg.Name { + err := s.handleOutboundMessage(addr, msg) + if err != nil { + log.Println("cannot send group message to remote user:", err) + } + continue + } + + err = s.handleUserMessage(addr, sender, msg) + if err != nil { + log.Println("cannot send group message to local user:", err) + } + } + return nil } -func (s *Server) handleOutboundMessage(sender net.Conn, addr core.Addr, msg core.Message) error { +func (s *Server) handleOutboundMessage(addr core.Addr, msg core.Message) error { remote, err := s.getRemote(addr.Host) if err != nil { return fmt.Errorf("cannot access remote server: %w", err) @@ -82,3 +108,15 @@ func (s *Server) handleOutboundMessage(sender net.Conn, addr core.Addr, msg core return core.Send(remote.Conn, msg) } + +func (s *Server) handleUserMessage(addr core.Addr, sender net.Conn, msg core.Message) error { + s.usersMu.RLock() + user, ok := s.users[addr.String()] + if !ok { + log.Println("user not found") + return core.Send(sender, core.Error{core.ErrorNotFound}) + } + s.usersMu.RUnlock() + + return user.Send(sender, msg) +} diff --git a/server/remote.go b/server/remote.go index e1829b1..70e4734 100644 --- a/server/remote.go +++ b/server/remote.go @@ -126,7 +126,7 @@ func (s *Server) readRemoteInput(conn net.Conn) error { func (s *Server) handleRemotePayload(sender net.Conn, payload any) error { switch v := payload.(type) { case core.Message: - return s.handleMessage(sender, core.ConnTypeServer, v) + return s.handleMessage(sender, core.ConnTypeServer, nil, v) default: return core.ErrUnexpectedPayloadType } diff --git a/server/server.go b/server/server.go index b5840df..af43e3a 100644 --- a/server/server.go +++ b/server/server.go @@ -11,14 +11,12 @@ import ( ) type Server struct { - cfg Config - users map[string]User // TODO: Use full address instead of just name - servers map[string]RemoteServer - channels map[string]*Channel - usersMu sync.RWMutex - serversMu sync.RWMutex - channelsMu sync.RWMutex - Storage Storage + cfg Config + users map[string]User // TODO: Use full address instead of just name + servers map[string]RemoteServer + usersMu sync.RWMutex + serversMu sync.RWMutex + Storage Storage } type Config struct { @@ -31,11 +29,10 @@ type Config struct { func NewServer(cfg Config, storage Storage) *Server { return &Server{ - cfg: cfg, - Storage: storage, - users: make(map[string]User), - servers: make(map[string]RemoteServer), - channels: make(map[string]*Channel), + cfg: cfg, + Storage: storage, + users: make(map[string]User), + servers: make(map[string]RemoteServer), } } @@ -47,13 +44,6 @@ func (s *Server) Start() error { return s.listenPlain() } -func (s *Server) AddChannel(name string) { - s.channelsMu.Lock() - defer s.channelsMu.Unlock() - s.channels[name] = NewChannel(name) - log.Println("created channel", name) -} - func (s *Server) listenPlain() error { ln, err := net.Listen("tcp", s.cfg.ListenAddr) if err != nil { diff --git a/server/storage.go b/server/storage.go index 037c40f..1cb725f 100644 --- a/server/storage.go +++ b/server/storage.go @@ -16,4 +16,5 @@ type Storage interface { SetPermission(data core.PermissionData) error GetPermission(user, channel string) (core.PermissionData, error) + GetChannelUsers(channel string) ([]string, error) } diff --git a/server/user.go b/server/user.go index 6ca857a..c615c40 100644 --- a/server/user.go +++ b/server/user.go @@ -9,16 +9,14 @@ import ( ) type User struct { - Name string - Conns map[net.Conn]struct{} - Channels map[string]*Channel + Addr string + Conns map[net.Conn]struct{} } -func NewUser(conn net.Conn, name string) User { +func NewUser(conn net.Conn, addr string) User { u := User{ - Name: name, - Conns: make(map[net.Conn]struct{}), - Channels: make(map[string]*Channel), + Addr: addr, + Conns: make(map[net.Conn]struct{}), } u.Conns[conn] = struct{}{} @@ -39,31 +37,31 @@ func (u *User) Send(senderConn net.Conn, payload core.Wrapper) error { } func (s *Server) handleUserConn(conn net.Conn) { - name, err := s.performUserAuth(conn) + addr, err := s.performUserAuth(conn) if err != nil { log.Println("user auth error:", err) return } s.usersMu.Lock() - user, ok := s.users[name] + user, ok := s.users[addr] if ok { - log.Println("next connection from user:", user.Name) + log.Println("next connection from user:", user.Addr) user.Conns[conn] = struct{}{} } else { - log.Println("initial connection from user:", name) - user = NewUser(conn, name) - s.users[name] = user + log.Println("initial connection from user:", addr) + user = NewUser(conn, addr) + s.users[addr] = user } s.usersMu.Unlock() defer func() { s.usersMu.Lock() - log.Println("client disconnected: ", user.Name) - delete(s.users[user.Name].Conns, conn) - if len(s.users[user.Name].Conns) == 0 { - log.Println("all connections closed for user:", user.Name) - delete(s.users, user.Name) + log.Println("client disconnected: ", user.Addr) + delete(s.users[user.Addr].Conns, conn) + if len(s.users[user.Addr].Conns) == 0 { + log.Println("all connections closed for user:", user.Addr) + delete(s.users, user.Addr) } s.usersMu.Unlock() }() @@ -84,7 +82,7 @@ func (s *Server) performUserAuth(conn net.Conn) (string, error) { return "", core.ErrUnexpectedPayloadType } - hash, err := s.Storage.GetUserPass(clientAuth.Name) + hash, err := s.Storage.GetUserPass(clientAuth.Addr) if err != nil { s.authFail(conn) return "", core.ErrAuthInvalidUser @@ -99,7 +97,7 @@ func (s *Server) performUserAuth(conn net.Conn) (string, error) { return "", err } - return clientAuth.Name, nil + return clientAuth.Addr, nil } func (s *Server) authFail(conn net.Conn) { @@ -123,7 +121,7 @@ func (s *Server) readUserInput(user *User, conn net.Conn) error { func (s *Server) handleUserPayload(user *User, sender net.Conn, payload any) error { switch v := payload.(type) { case core.Message: - return s.handleMessage(sender, core.ConnTypeUser, v) + return s.handleMessage(sender, core.ConnTypeUser, user, v) case core.Usermode: return s.handleUsermode(user, sender, v) case core.History: @@ -134,36 +132,25 @@ func (s *Server) handleUserPayload(user *User, sender net.Conn, payload any) err } func (s *Server) handleUsermode(user *User, conn net.Conn, mode core.Usermode) error { - userAddr, err := core.ReadAddr(mode.UserAddr) - if err != nil { - return err - } - - chanAddr, err := core.ReadAddr(mode.ChannelName) - if err != nil { - return err - } - if user.Name != userAddr.Channel { - log.Println("unauthorized") - return user.Send(conn, core.Error{core.ErrorUnauthorized}) - } + /* + userAddr, err := core.ReadAddr(mode.UserAddr) + if err != nil { + return err + } - s.channelsMu.RLock() - channel, ok := s.channels[chanAddr.Channel] - if !ok { - log.Println("not found", userAddr.Channel) - return user.Send(conn, core.Error{core.ErrorNotFound}) - } - s.channelsMu.RUnlock() + chanAddr, err := core.ReadAddr(mode.ChannelAddr) + if err != nil { + return err + } - switch mode.Mode { - case core.UsermodeNone: - channel.Remove(user) - case core.UsermodeInChannel: - channel.Add(user) - } + if user.Addr != userAddr.Channel { + log.Println("unauthorized") + return user.Send(conn, core.Error{core.ErrorUnauthorized}) + } + // TODO: change user permissions here + */ return nil } @@ -177,7 +164,7 @@ func (s *Server) handleHistory(user *User, conn net.Conn, hist core.History) err return user.Send(conn, core.Error{core.ErrorNotFound}) } - perm, err := s.Storage.GetPermission(user.Name+"@"+s.cfg.Name, hist.Channel) + perm, err := s.Storage.GetPermission(user.Addr, hist.Channel) if err != nil { fmt.Println("cannot get message history:", err) return user.Send(conn, core.Error{core.ErrorNotFound}) diff --git a/storage/storage.go b/storage/storage.go index 510b587..1292945 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -17,7 +17,7 @@ const initSql = ` (id INTEGER NOT NULL PRIMARY KEY, source STRING, target STRING, timestamp INT, content STRING); CREATE TABLE IF NOT EXISTS users - (name STRING NOT NULL PRIMARY KEY, pass STRING); + (addr STRING NOT NULL PRIMARY KEY, pass STRING); CREATE TABLE IF NOT EXISTS permissions (user STRING, channel STRING, read INT, write INT, PRIMARY KEY (user, channel)); @@ -76,22 +76,22 @@ func (db *Database) GetHistory(channel string, since time.Time, num int, offset func (db *Database) SetUser(user core.UserData) error { _, err := db.Exec( - "INSERT OR REPLACE INTO users (name, pass) VALUES (?, ?);", + "INSERT OR REPLACE INTO users (addr, pass) VALUES (?, ?);", user.Name, user.Pass, ) return err } -func (db *Database) DelUser(name string) error { - _, err := db.Exec("DELETE FROM users WHERE name = ?", name) +func (db *Database) DelUser(addr string) error { + _, err := db.Exec("DELETE FROM users WHERE addr = ?", addr) return err } -func (db *Database) GetUserPass(name string) (string, error) { +func (db *Database) GetUserPass(addr string) (string, error) { var pass string - err := db.QueryRow("SELECT pass FROM users WHERE name = ?", name).Scan(&pass) + err := db.QueryRow("SELECT pass FROM users WHERE addr = ?", addr).Scan(&pass) if err != nil { return "", err } @@ -126,6 +126,31 @@ func (db *Database) GetPermission(user, channel string) (core.PermissionData, er }, nil } +func (db *Database) GetChannelUsers(channel string) (users []string, err error) { + rows, err := db.Query("SELECT user FROM permissions WHERE channel = ? AND write = 1;", channel) + defer func() { + if rows == nil { + return + } + if err := rows.Close(); err != nil { + log.Println("cannot close database row:", err) + } + }() + if err != nil { + return users, err + } + + for rows.Next() { + var user string + if err := rows.Scan(&user); err != nil { + return users, err + } + users = append(users, user) + } + + return users, nil +} + func itob(v int) bool { if v == 1 { return true diff --git a/tools/run-tls.sh b/tools/run-tls.sh index 894a9f7..c1b84ea 100755 --- a/tools/run-tls.sh +++ b/tools/run-tls.sh @@ -1,8 +1,8 @@ #!/bin/sh tmux \ - new-session "go run cmd/daemon/main.go -tls -tls-cert cert.pem -tls-key key.pem; read" \; \ - split-window "sleep 0.5; go run cmd/client/main.go -tls -tls-insecure -u user1; read" \; \ - split-window "sleep 0.5; go run cmd/client/main.go -tls -tls-insecure -u user2; read" \; \ - split-window "sleep 0.5; go run cmd/client/main.go -tls -tls-insecure -u user3; read" \; \ + new-session "go run cmd/daemon/main.go -tls -tls-cert cert.pem -tls-key key.pem -test -db test.db; read" \; \ + split-window "sleep 1; go run cmd/client/main.go -tls -tls-insecure -u user1; read" \; \ + split-window "sleep 1; go run cmd/client/main.go -tls -tls-insecure -u user2; read" \; \ + split-window "sleep 1; go run cmd/client/main.go -tls -tls-insecure -u user3; read" \; \ select-layout tiled;
\ No newline at end of file diff --git a/tools/run-two-clients.sh b/tools/run-two-clients.sh index 884a88a..d5d9294 100755 --- a/tools/run-two-clients.sh +++ b/tools/run-two-clients.sh @@ -1,8 +1,8 @@ #!/bin/sh tmux \ - new-session "go run cmd/daemon/main.go; read" \; \ - split-window "sleep 0.5; go run cmd/client/main.go -u user1; read" \; \ - split-window "sleep 0.5; go run cmd/client/main.go -u user1; read" \; \ - split-window "sleep 0.5; go run cmd/client/main.go -u user2; read" \; \ + new-session "go run cmd/daemon/main.go -test -db test.db; read" \; \ + split-window "sleep 1; go run cmd/client/main.go -u user1@localhost -p test; read" \; \ + split-window "sleep 1; go run cmd/client/main.go -u user1@localhost -p test; read" \; \ + split-window "sleep 1; go run cmd/client/main.go -u user2@localhost -p test; read" \; \ select-layout tiled;
\ No newline at end of file diff --git a/tools/run.sh b/tools/run.sh index dbe0bf0..2245c90 100755 --- a/tools/run.sh +++ b/tools/run.sh @@ -1,8 +1,8 @@ #!/bin/sh tmux \ - new-session "go run cmd/daemon/main.go; read" \; \ - split-window "sleep 0.5; go run cmd/client/main.go -u user1; read" \; \ - split-window "sleep 0.5; go run cmd/client/main.go -u user2; read" \; \ - split-window "sleep 0.5; go run cmd/client/main.go -u user3; read" \; \ + new-session "go run cmd/daemon/main.go -test -db test.db; read" \; \ + split-window "sleep 1; go run cmd/client/main.go -u user1@localhost -p test; read" \; \ + split-window "sleep 1; go run cmd/client/main.go -u user2@localhost -p test; read" \; \ + split-window "sleep 1; go run cmd/client/main.go -u user3@localhost -p test; read" \; \ select-layout tiled;
\ No newline at end of file |
