summaryrefslogtreecommitdiffstats
path: root/server
diff options
context:
space:
mode:
Diffstat (limited to 'server')
-rw-r--r--server/channel.go48
-rw-r--r--server/message.go74
-rw-r--r--server/remote.go2
-rw-r--r--server/server.go30
-rw-r--r--server/storage.go1
-rw-r--r--server/user.go83
6 files changed, 103 insertions, 135 deletions
diff --git a/server/channel.go b/server/channel.go
deleted file mode 100644
index 11969f2..0000000
--- a/server/channel.go
+++ /dev/null
@@ -1,48 +0,0 @@
-package server
-
-import (
- "log"
- "net"
- "sync"
-
- "go.rctt.net/solec/core"
-)
-
-type Channel struct {
- Name string
- Users map[string]*User
- UsersMu sync.RWMutex
-}
-
-func NewChannel(name string) *Channel {
- return &Channel{
- Name: name,
- Users: make(map[string]*User),
- }
-}
-
-func (c *Channel) Add(u *User) {
- c.UsersMu.Lock()
- c.Users[u.Name] = u
- u.Channels[c.Name] = c
- c.UsersMu.Unlock()
-
- log.Println("user joined a channel")
-}
-
-func (c *Channel) Remove(u *User) {
- c.UsersMu.Lock()
- delete(c.Users, u.Name)
- delete(u.Channels, c.Name)
- c.UsersMu.Unlock()
-
- log.Println("user left a channel")
-}
-
-func (c *Channel) Send(senderConn net.Conn, msg core.Message) {
- for _, u := range c.Users {
- if err := u.Send(senderConn, msg); err != nil {
- log.Print("cannot send a message to user on channel", err)
- }
- }
-}
diff --git a/server/message.go b/server/message.go
index 19bfadf..801a4cf 100644
--- a/server/message.go
+++ b/server/message.go
@@ -30,7 +30,7 @@ func (s *Server) SendBroadcast(msg string) {
}
}
-func (s *Server) handleMessage(sender net.Conn, connType core.ConnType, msg core.Message) error {
+func (s *Server) handleMessage(sender net.Conn, connType core.ConnType, senderUser *User, msg core.Message) error {
if connType == core.ConnTypeUser {
msg.Timestamp = time.Now()
}
@@ -47,34 +47,60 @@ func (s *Server) handleMessage(sender net.Conn, connType core.ConnType, msg core
}
if addr.Host == s.cfg.Name {
- return s.handleLocalMessage(sender, addr, msg)
+ return s.handleLocalMessage(sender, senderUser, addr, msg)
}
- return s.handleOutboundMessage(sender, addr, msg)
+ return s.handleOutboundMessage(addr, msg)
}
-func (s *Server) handleLocalMessage(sender net.Conn, addr core.Addr, msg core.Message) error {
- if addr.Type == core.AddrUser {
- s.usersMu.RLock()
- user, ok := s.users[addr.Channel]
- if !ok {
- return core.Send(sender, core.Error{core.ErrorNotFound})
- }
- s.usersMu.RUnlock()
- return user.Send(sender, msg)
+func (s *Server) handleLocalMessage(sender net.Conn, senderUser *User, addr core.Addr, msg core.Message) error {
+ perm, err := s.Storage.GetPermission(senderUser.Addr, addr.String())
+ if err != nil {
+ log.Println("cannot get channel permissions:", err)
+ return core.Send(sender, core.Error{core.ErrorNotFound})
}
- s.channelsMu.RLock()
- channel, ok := s.channels[addr.Channel]
- if !ok {
+ if !perm.Write {
+ log.Println("user not authorized")
return core.Send(sender, core.Error{core.ErrorNotFound})
}
- s.channelsMu.RUnlock()
- channel.Send(sender, msg)
+
+ if addr.Type == core.AddrUser {
+ s.handleUserMessage(addr, sender, msg)
+ return nil
+ }
+
+ users, err := s.Storage.GetChannelUsers(addr.String())
+ if err != nil {
+ log.Println("cannot get channel users:", err)
+ return core.Send(sender, core.Error{core.ErrorUnknown})
+ }
+
+ for _, u := range users {
+ addr, err := core.ReadAddr(u)
+ if err != nil {
+ log.Println("cannot read user address:", err)
+ continue
+ }
+
+ if addr.Host != s.cfg.Name {
+ err := s.handleOutboundMessage(addr, msg)
+ if err != nil {
+ log.Println("cannot send group message to remote user:", err)
+ }
+ continue
+ }
+
+ err = s.handleUserMessage(addr, sender, msg)
+ if err != nil {
+ log.Println("cannot send group message to local user:", err)
+ }
+ }
+
return nil
}
-func (s *Server) handleOutboundMessage(sender net.Conn, addr core.Addr, msg core.Message) error {
+func (s *Server) handleOutboundMessage(addr core.Addr, msg core.Message) error {
remote, err := s.getRemote(addr.Host)
if err != nil {
return fmt.Errorf("cannot access remote server: %w", err)
@@ -82,3 +108,15 @@ func (s *Server) handleOutboundMessage(sender net.Conn, addr core.Addr, msg core
return core.Send(remote.Conn, msg)
}
+
+func (s *Server) handleUserMessage(addr core.Addr, sender net.Conn, msg core.Message) error {
+ s.usersMu.RLock()
+ user, ok := s.users[addr.String()]
+ if !ok {
+ log.Println("user not found")
+ return core.Send(sender, core.Error{core.ErrorNotFound})
+ }
+ s.usersMu.RUnlock()
+
+ return user.Send(sender, msg)
+}
diff --git a/server/remote.go b/server/remote.go
index e1829b1..70e4734 100644
--- a/server/remote.go
+++ b/server/remote.go
@@ -126,7 +126,7 @@ func (s *Server) readRemoteInput(conn net.Conn) error {
func (s *Server) handleRemotePayload(sender net.Conn, payload any) error {
switch v := payload.(type) {
case core.Message:
- return s.handleMessage(sender, core.ConnTypeServer, v)
+ return s.handleMessage(sender, core.ConnTypeServer, nil, v)
default:
return core.ErrUnexpectedPayloadType
}
diff --git a/server/server.go b/server/server.go
index b5840df..af43e3a 100644
--- a/server/server.go
+++ b/server/server.go
@@ -11,14 +11,12 @@ import (
)
type Server struct {
- cfg Config
- users map[string]User // TODO: Use full address instead of just name
- servers map[string]RemoteServer
- channels map[string]*Channel
- usersMu sync.RWMutex
- serversMu sync.RWMutex
- channelsMu sync.RWMutex
- Storage Storage
+ cfg Config
+ users map[string]User // TODO: Use full address instead of just name
+ servers map[string]RemoteServer
+ usersMu sync.RWMutex
+ serversMu sync.RWMutex
+ Storage Storage
}
type Config struct {
@@ -31,11 +29,10 @@ type Config struct {
func NewServer(cfg Config, storage Storage) *Server {
return &Server{
- cfg: cfg,
- Storage: storage,
- users: make(map[string]User),
- servers: make(map[string]RemoteServer),
- channels: make(map[string]*Channel),
+ cfg: cfg,
+ Storage: storage,
+ users: make(map[string]User),
+ servers: make(map[string]RemoteServer),
}
}
@@ -47,13 +44,6 @@ func (s *Server) Start() error {
return s.listenPlain()
}
-func (s *Server) AddChannel(name string) {
- s.channelsMu.Lock()
- defer s.channelsMu.Unlock()
- s.channels[name] = NewChannel(name)
- log.Println("created channel", name)
-}
-
func (s *Server) listenPlain() error {
ln, err := net.Listen("tcp", s.cfg.ListenAddr)
if err != nil {
diff --git a/server/storage.go b/server/storage.go
index 037c40f..1cb725f 100644
--- a/server/storage.go
+++ b/server/storage.go
@@ -16,4 +16,5 @@ type Storage interface {
SetPermission(data core.PermissionData) error
GetPermission(user, channel string) (core.PermissionData, error)
+ GetChannelUsers(channel string) ([]string, error)
}
diff --git a/server/user.go b/server/user.go
index 6ca857a..c615c40 100644
--- a/server/user.go
+++ b/server/user.go
@@ -9,16 +9,14 @@ import (
)
type User struct {
- Name string
- Conns map[net.Conn]struct{}
- Channels map[string]*Channel
+ Addr string
+ Conns map[net.Conn]struct{}
}
-func NewUser(conn net.Conn, name string) User {
+func NewUser(conn net.Conn, addr string) User {
u := User{
- Name: name,
- Conns: make(map[net.Conn]struct{}),
- Channels: make(map[string]*Channel),
+ Addr: addr,
+ Conns: make(map[net.Conn]struct{}),
}
u.Conns[conn] = struct{}{}
@@ -39,31 +37,31 @@ func (u *User) Send(senderConn net.Conn, payload core.Wrapper) error {
}
func (s *Server) handleUserConn(conn net.Conn) {
- name, err := s.performUserAuth(conn)
+ addr, err := s.performUserAuth(conn)
if err != nil {
log.Println("user auth error:", err)
return
}
s.usersMu.Lock()
- user, ok := s.users[name]
+ user, ok := s.users[addr]
if ok {
- log.Println("next connection from user:", user.Name)
+ log.Println("next connection from user:", user.Addr)
user.Conns[conn] = struct{}{}
} else {
- log.Println("initial connection from user:", name)
- user = NewUser(conn, name)
- s.users[name] = user
+ log.Println("initial connection from user:", addr)
+ user = NewUser(conn, addr)
+ s.users[addr] = user
}
s.usersMu.Unlock()
defer func() {
s.usersMu.Lock()
- log.Println("client disconnected: ", user.Name)
- delete(s.users[user.Name].Conns, conn)
- if len(s.users[user.Name].Conns) == 0 {
- log.Println("all connections closed for user:", user.Name)
- delete(s.users, user.Name)
+ log.Println("client disconnected: ", user.Addr)
+ delete(s.users[user.Addr].Conns, conn)
+ if len(s.users[user.Addr].Conns) == 0 {
+ log.Println("all connections closed for user:", user.Addr)
+ delete(s.users, user.Addr)
}
s.usersMu.Unlock()
}()
@@ -84,7 +82,7 @@ func (s *Server) performUserAuth(conn net.Conn) (string, error) {
return "", core.ErrUnexpectedPayloadType
}
- hash, err := s.Storage.GetUserPass(clientAuth.Name)
+ hash, err := s.Storage.GetUserPass(clientAuth.Addr)
if err != nil {
s.authFail(conn)
return "", core.ErrAuthInvalidUser
@@ -99,7 +97,7 @@ func (s *Server) performUserAuth(conn net.Conn) (string, error) {
return "", err
}
- return clientAuth.Name, nil
+ return clientAuth.Addr, nil
}
func (s *Server) authFail(conn net.Conn) {
@@ -123,7 +121,7 @@ func (s *Server) readUserInput(user *User, conn net.Conn) error {
func (s *Server) handleUserPayload(user *User, sender net.Conn, payload any) error {
switch v := payload.(type) {
case core.Message:
- return s.handleMessage(sender, core.ConnTypeUser, v)
+ return s.handleMessage(sender, core.ConnTypeUser, user, v)
case core.Usermode:
return s.handleUsermode(user, sender, v)
case core.History:
@@ -134,36 +132,25 @@ func (s *Server) handleUserPayload(user *User, sender net.Conn, payload any) err
}
func (s *Server) handleUsermode(user *User, conn net.Conn, mode core.Usermode) error {
- userAddr, err := core.ReadAddr(mode.UserAddr)
- if err != nil {
- return err
- }
-
- chanAddr, err := core.ReadAddr(mode.ChannelName)
- if err != nil {
- return err
- }
- if user.Name != userAddr.Channel {
- log.Println("unauthorized")
- return user.Send(conn, core.Error{core.ErrorUnauthorized})
- }
+ /*
+ userAddr, err := core.ReadAddr(mode.UserAddr)
+ if err != nil {
+ return err
+ }
- s.channelsMu.RLock()
- channel, ok := s.channels[chanAddr.Channel]
- if !ok {
- log.Println("not found", userAddr.Channel)
- return user.Send(conn, core.Error{core.ErrorNotFound})
- }
- s.channelsMu.RUnlock()
+ chanAddr, err := core.ReadAddr(mode.ChannelAddr)
+ if err != nil {
+ return err
+ }
- switch mode.Mode {
- case core.UsermodeNone:
- channel.Remove(user)
- case core.UsermodeInChannel:
- channel.Add(user)
- }
+ if user.Addr != userAddr.Channel {
+ log.Println("unauthorized")
+ return user.Send(conn, core.Error{core.ErrorUnauthorized})
+ }
+ // TODO: change user permissions here
+ */
return nil
}
@@ -177,7 +164,7 @@ func (s *Server) handleHistory(user *User, conn net.Conn, hist core.History) err
return user.Send(conn, core.Error{core.ErrorNotFound})
}
- perm, err := s.Storage.GetPermission(user.Name+"@"+s.cfg.Name, hist.Channel)
+ perm, err := s.Storage.GetPermission(user.Addr, hist.Channel)
if err != nil {
fmt.Println("cannot get message history:", err)
return user.Send(conn, core.Error{core.ErrorNotFound})