Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions cmd/client/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -561,13 +561,16 @@ func (m *chatModel) refreshChannels(showMessage bool) {
return
}
needsRefresh := false
allEncrypted := true
for _, ch := range channels {
if _, err := m.ensureChannelKey(ch.ID); err != nil {
needsRefresh = true
}
name := m.decryptChannelName(ch.ID, ch.NameEnc)
if name == "<encrypted>" {
needsRefresh = true
} else {
allEncrypted = false
}
m.channels[ch.ID] = channelInfo{ID: ch.ID, Name: name}
}
Expand All @@ -583,6 +586,9 @@ func (m *chatModel) refreshChannels(showMessage bool) {
b.WriteString(")")
}
m.appendSystemMessage(b.String())
if allEncrypted {
m.appendSystemMessage("All channels are <encrypted> because no other online user has shared channel keys yet. They will decrypt once another user comes online and shares keys (or an admin re-shares them).")
}
}
m.ensureSidebarIndex()
m.channelRefreshNeeded = needsRefresh
Expand Down Expand Up @@ -1207,6 +1213,8 @@ func (m *chatModel) renderMessages() string {
}
if msg.isSystem {
sender = "system"
} else {
sender = formatUsername(sender)
}

var style lipgloss.Style
Expand Down Expand Up @@ -1252,7 +1260,7 @@ func (m *chatModel) renderSidebar() string {
style = sidebarOfflineStyle
}
}
name := entry.Name
name := formatUsername(entry.Name)
if entry.Admin {
name = fmt.Sprintf("%s (admin)", name)
}
Expand All @@ -1278,7 +1286,7 @@ func (m *chatModel) renderSidebar() string {
style = sidebarOfflineStyle
}
}
name := entry.Name
name := formatUsername(entry.Name)
if entry.Admin {
name = fmt.Sprintf("%s (admin)", name)
}
Expand Down Expand Up @@ -1637,7 +1645,7 @@ func (m chatModel) View() string {
header := fmt.Sprintf(
" %s %s %s %s",
appNameStyle.Render("* dialtone"),
headerStyle.Render(m.auth.Username),
headerStyle.Render(formatUsername(m.auth.Username)),
labelStyle.Render(shortID(m.auth.UserID)),
labelStyle.Render(m.activeChannelLabel()),
)
Expand Down Expand Up @@ -1702,6 +1710,17 @@ func shortID(id string) string {
return id
}

func formatUsername(name string) string {
name = strings.TrimSpace(name)
if name == "" {
return name
}
if strings.HasPrefix(name, "<") && strings.HasSuffix(name, ">") {
return name
}
return "<" + name + ">"
}

func clampMin(v, minimum int) int {
if v < minimum {
return minimum
Expand Down
15 changes: 12 additions & 3 deletions cmd/client/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ type loginModel struct {
selectIndex int
}

const (
minUsernameLen = 2
maxUsernameLen = 20
)

func newLoginModel(defaultServer string) loginModel {
server := textinput.New()
server.Placeholder = "http://localhost:8080"
Expand All @@ -41,8 +46,8 @@ func newLoginModel(defaultServer string) loginModel {
server.Focus()

username := textinput.New()
username.Placeholder = "username"
username.CharLimit = 64
username.Placeholder = "username (2-20 chars)"
username.CharLimit = maxUsernameLen
username.Width = 30

password := textinput.New()
Expand Down Expand Up @@ -307,9 +312,13 @@ func (m loginModel) validateSubmit() string {
if strings.TrimSpace(m.serverURL()) == "" {
return "server url is required"
}
if m.username() == "" || m.password() == "" {
username := strings.TrimSpace(m.username())
if username == "" || m.password() == "" {
return "username and password are required"
}
if len(username) < minUsernameLen || len(username) > maxUsernameLen {
return fmt.Sprintf("username must be %d-%d characters", minUsernameLen, maxUsernameLen)
}
if len(m.passphrase()) < 8 {
return "keystore passphrase must be at least 8 characters"
}
Expand Down
14 changes: 14 additions & 0 deletions cmd/client/login_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"strings"
"testing"

tea "github.com/charmbracelet/bubbletea"
Expand All @@ -15,6 +16,19 @@ func TestLoginValidateSubmit(t *testing.T) {
t.Fatalf("unexpected error: %s", msg)
}

m.usernameInput.SetValue("a")
if msg := m.validateSubmit(); msg == "" {
t.Fatalf("expected username length error")
}

m.usernameInput.CharLimit = 0
m.usernameInput.SetValue(strings.Repeat("a", 21))
if msg := m.validateSubmit(); msg == "" {
t.Fatalf("expected username max length error")
}

m.usernameInput.SetValue("alice")

m.passphraseInp.SetValue("short")
if msg := m.validateSubmit(); msg == "" {
t.Fatalf("expected passphrase error")
Expand Down
23 changes: 19 additions & 4 deletions internal/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ var (
ErrTokenExpired = errors.New("token expired")
)

const (
minUsernameLen = 2
maxUsernameLen = 20
)

type Session struct {
Token string
UserID user.ID
Expand Down Expand Up @@ -59,8 +64,8 @@ func (s *Service) Register(ctx context.Context, username, password, publicKey, i
return user.User{}, device.Device{}, Session{}, errors.New("invites service is required")
}
name := normalizeUsername(username)
if name == "" {
return user.User{}, device.Device{}, Session{}, ErrInvalidInput
if err := validateUsername(name); err != nil {
return user.User{}, device.Device{}, Session{}, err
}
if err := validateRegisterPassword(password); err != nil {
return user.User{}, device.Device{}, Session{}, err
Expand Down Expand Up @@ -107,8 +112,8 @@ func (s *Service) Login(ctx context.Context, username, password, publicKey strin
return user.User{}, device.Device{}, Session{}, errors.New("services are required")
}
name := normalizeUsername(username)
if name == "" {
return user.User{}, device.Device{}, Session{}, ErrInvalidInput
if err := validateUsername(name); err != nil {
return user.User{}, device.Device{}, Session{}, err
}
if strings.TrimSpace(password) == "" || len(password) < 8 {
return user.User{}, device.Device{}, Session{}, ErrInvalidInput
Expand Down Expand Up @@ -219,6 +224,16 @@ func normalizeUsername(username string) string {
return strings.ToLower(strings.TrimSpace(username))
}

func validateUsername(name string) error {
if name == "" {
return ErrInvalidInput
}
if len(name) < minUsernameLen || len(name) > maxUsernameLen {
return fmt.Errorf("%w: username must be %d-%d characters", ErrInvalidInput, minUsernameLen, maxUsernameLen)
}
return nil
}

type tokenStore struct {
mu sync.Mutex
sessions map[string]Session
Expand Down
17 changes: 17 additions & 0 deletions internal/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,23 @@ func TestRegister_EmptyUsername(t *testing.T) {
}
}

func TestRegister_ShortUsername(t *testing.T) {
svc := newTestService()
_, _, _, err := svc.Register(context.Background(), "a", "password123", "a2V5", "invite-1")
if !errors.Is(err, ErrInvalidInput) {
t.Fatalf("expected ErrInvalidInput, got %v", err)
}
}

func TestRegister_LongUsername(t *testing.T) {
svc := newTestService()
longName := strings.Repeat("a", 21)
_, _, _, err := svc.Register(context.Background(), longName, "password123", "a2V5", "invite-1")
if !errors.Is(err, ErrInvalidInput) {
t.Fatalf("expected ErrInvalidInput, got %v", err)
}
}

func TestLogin_Success(t *testing.T) {
svc := newTestService()
ctx := context.Background()
Expand Down
33 changes: 27 additions & 6 deletions internal/user/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@ import (

var ErrInvalidInput = errors.New("invalid input")

const (
minUsernameLen = 2
maxUsernameLen = 20
)

type Service struct {
repo Repository
idGen func() ID
Expand All @@ -38,8 +43,8 @@ func (s *Service) Create(ctx context.Context, username string) (User, error) {
}

name := normalizeUsername(username)
if name == "" {
return User{}, ErrInvalidInput
if err := validateUsername(name); err != nil {
return User{}, err
}
if len(s.pepper) == 0 {
return User{}, errors.New("username pepper is required")
Expand Down Expand Up @@ -67,7 +72,10 @@ func (s *Service) CreateWithPassword(ctx context.Context, username, passwordHash
}

name := normalizeUsername(username)
if name == "" || strings.TrimSpace(passwordHash) == "" {
if err := validateUsername(name); err != nil {
return User{}, err
}
if strings.TrimSpace(passwordHash) == "" {
return User{}, ErrInvalidInput
}
if len(s.pepper) == 0 {
Expand Down Expand Up @@ -98,7 +106,10 @@ func (s *Service) CreateWithPasswordAndID(ctx context.Context, id ID, username,
return User{}, ErrInvalidInput
}
name := normalizeUsername(username)
if name == "" || strings.TrimSpace(passwordHash) == "" {
if err := validateUsername(name); err != nil {
return User{}, err
}
if strings.TrimSpace(passwordHash) == "" {
return User{}, ErrInvalidInput
}
if len(s.pepper) == 0 {
Expand Down Expand Up @@ -147,8 +158,8 @@ func (s *Service) GetByUsername(ctx context.Context, username string) (User, err
return User{}, errors.New("repository is required")
}
name := normalizeUsername(username)
if name == "" {
return User{}, ErrInvalidInput
if err := validateUsername(name); err != nil {
return User{}, err
}
if len(s.pepper) == 0 {
return User{}, errors.New("username pepper is required")
Expand Down Expand Up @@ -208,6 +219,16 @@ func normalizeUsername(username string) string {
return strings.ToLower(strings.TrimSpace(username))
}

func validateUsername(name string) error {
if name == "" {
return ErrInvalidInput
}
if len(name) < minUsernameLen || len(name) > maxUsernameLen {
return ErrInvalidInput
}
return nil
}

func hashUsername(pepper []byte, username string) string {
mac := hmac.New(sha256.New, pepper)
_, _ = mac.Write([]byte(username))
Expand Down
20 changes: 20 additions & 0 deletions internal/user/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package user
import (
"context"
"errors"
"strings"
"testing"
"time"
)
Expand Down Expand Up @@ -119,6 +120,25 @@ func TestCreate_WhitespaceUsername(t *testing.T) {
}
}

func TestCreate_ShortUsername(t *testing.T) {
svc, _ := newTestService()

_, err := svc.Create(context.Background(), "a")
if !errors.Is(err, ErrInvalidInput) {
t.Fatalf("expected ErrInvalidInput, got %v", err)
}
}

func TestCreate_LongUsername(t *testing.T) {
svc, _ := newTestService()
longName := strings.Repeat("a", 21)

_, err := svc.Create(context.Background(), longName)
if !errors.Is(err, ErrInvalidInput) {
t.Fatalf("expected ErrInvalidInput, got %v", err)
}
}

func TestCreate_TrimsWhitespace(t *testing.T) {
svc, _ := newTestService()

Expand Down
Loading