diff --git a/go.mod b/go.mod index 3674a7f..34d32cf 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,8 @@ module github.com/df-mc/go-xsapi -go 1.22.0 +go 1.23.0 + +require ( + github.com/coder/websocket v1.8.12 + github.com/google/uuid v1.6.0 +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..9498948 --- /dev/null +++ b/go.sum @@ -0,0 +1,4 @@ +github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo= +github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= diff --git a/internal/attr.go b/internal/attr.go new file mode 100644 index 0000000..0a1d146 --- /dev/null +++ b/internal/attr.go @@ -0,0 +1,7 @@ +package internal + +import "log/slog" + +const errorKey = "error" + +func ErrAttr(err error) slog.Attr { return slog.Any(errorKey, err) } diff --git a/internal/transport.go b/internal/transport.go new file mode 100644 index 0000000..e664abe --- /dev/null +++ b/internal/transport.go @@ -0,0 +1,22 @@ +package internal + +import ( + "github.com/df-mc/go-xsapi" + "net/http" +) + +func SetTransport(client *http.Client, src xsapi.TokenSource) { + var ( + hasTransport bool + base = client.Transport + ) + if base != nil { + _, hasTransport = base.(*xsapi.Transport) + } + if !hasTransport { + client.Transport = &xsapi.Transport{ + Source: src, + Base: base, + } + } +} diff --git a/mpsd/activity.go b/mpsd/activity.go new file mode 100644 index 0000000..2ec79d1 --- /dev/null +++ b/mpsd/activity.go @@ -0,0 +1,158 @@ +package mpsd + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "github.com/df-mc/go-xsapi" + "github.com/df-mc/go-xsapi/internal" + "github.com/google/uuid" + "net/http" + "net/url" + "strconv" + "time" +) + +// An ActivityFilter specifies a filter applied for searching activities on [ActivityFilter.Search] +type ActivityFilter struct { + // Client is a [http.Client] to be used to do HTTP requests. If nil, http.DefaultClient will be copied. + Client *http.Client + + // SocialGroup specifies a group that contains handles of activities. + SocialGroup string + // SocialGroupXUID references a user that does searching on specific SocialGroup. + SocialGroupXUID string +} + +func (f ActivityFilter) Search(src xsapi.TokenSource, serviceConfigID uuid.UUID) ([]ActivityHandle, error) { + if f.Client == nil { + f.Client = new(http.Client) + *f.Client = *http.DefaultClient + } + internal.SetTransport(f.Client, src) + + owners := make(map[string]any) + if f.SocialGroup != "" { + if f.SocialGroupXUID == "" { + tok, err := src.Token() + if err != nil { + return nil, fmt.Errorf("request token: %w", err) + } + f.SocialGroupXUID = tok.DisplayClaims().XUID + } + owners["people"] = map[string]any{ + "moniker": f.SocialGroup, + "monikerXuid": f.SocialGroupXUID, + } + } + + buf := &bytes.Buffer{} + if err := json.NewEncoder(buf).Encode(map[string]any{ + "type": "activity", + "scid": serviceConfigID, + "owners": owners, + }); err != nil { + return nil, fmt.Errorf("encode request body: %w", err) + } + req, err := http.NewRequest(http.MethodPost, searchURL.String(), buf) + if err != nil { + return nil, fmt.Errorf("make request: %w", err) + } + req.Header.Set("X-Xbl-Contract-Version", strconv.Itoa(contractVersion)) + + resp, err := f.Client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + switch resp.StatusCode { + case http.StatusOK: + var data struct { + Results []ActivityHandle `json:"results"` + } + if err := json.NewDecoder(resp.Body).Decode(&data); err != nil { + return nil, fmt.Errorf("decode response body: %w", err) + } + return data.Results, nil + default: + return nil, fmt.Errorf("%s %s: %s", req.Method, req.URL, resp.Status) + } +} + +func (conf PublishConfig) commitActivity(ctx context.Context, ref SessionReference) error { + buf := &bytes.Buffer{} + if err := json.NewEncoder(buf).Encode(activityHandle{ + Type: "activity", + SessionReference: ref, + Version: 1, + }); err != nil { + return fmt.Errorf("encode request body: %w", err) + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, handlesURL.String(), buf) + if err != nil { + return fmt.Errorf("make request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-Xbl-Contract-Version", strconv.Itoa(contractVersion)) + + resp, err := conf.Client.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + switch resp.StatusCode { + case http.StatusOK, http.StatusCreated: + return nil + default: + return fmt.Errorf("%s %s: %s", req.Method, req.URL, resp.Status) + } +} + +var ( + handlesURL = &url.URL{ + Scheme: "https", + Host: "sessiondirectory.xboxlive.com", + Path: "/handles", + } + + searchURL = &url.URL{ + Scheme: "https", + Host: "sessiondirectory.xboxlive.com", + Path: "/handles/query", + RawQuery: url.Values{ + "include": []string{"relatedInfo,customProperties"}, + }.Encode(), + } +) + +type activityHandle struct { + Type string `json:"type"` // Always "activity". + SessionReference SessionReference `json:"sessionRef,omitempty"` + Version int `json:"version"` // Always 1. + OwnerXUID string `json:"ownerXuid,omitempty"` +} + +type ActivityHandle struct { + activityHandle + CreateTime time.Time `json:"createTime,omitempty"` + CustomProperties json.RawMessage `json:"customProperties,omitempty"` + GameTypes json.RawMessage `json:"gameTypes,omitempty"` + ID uuid.UUID `json:"id,omitempty"` + InviteProtocol string `json:"inviteProtocol,omitempty"` + RelatedInfo *ActivityHandleRelatedInfo `json:"relatedInfo,omitempty"` + TitleID string `json:"titleId,omitempty"` +} + +type ActivityHandleRelatedInfo struct { + Closed bool `json:"closed,omitempty"` + InviteProtocol string `json:"inviteProtocol,omitempty"` + JoinRestriction string `json:"joinRestriction,omitempty"` + MaxMembersCount uint32 `json:"maxMembersCount,omitempty"` + PostedTime time.Time `json:"postedTime,omitempty"` + Visibility string `json:"visibility,omitempty"` +} + +const ( + SocialGroupPeople = "people" +) diff --git a/mpsd/commit.go b/mpsd/commit.go new file mode 100644 index 0000000..34b2bf2 --- /dev/null +++ b/mpsd/commit.go @@ -0,0 +1,89 @@ +package mpsd + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "github.com/google/uuid" + "net/http" + "net/url" + "path" + "strconv" + "time" +) + +// Commit pushes a [SessionDescription] into the session, updating properties and other fields +// on the service. +func (s *Session) Commit(ctx context.Context, d *SessionDescription) (*Commit, error) { + return s.conf.commit(ctx, s.ref.URL(), d) +} + +// commit puts a [SessionDescription] on the URL. It is used for creating and updating the description +// of the Session. +func (conf PublishConfig) commit(ctx context.Context, u *url.URL, d *SessionDescription) (*Commit, error) { + buf := &bytes.Buffer{} + if err := json.NewEncoder(buf).Encode(d); err != nil { + return nil, fmt.Errorf("encode request body: %w", err) + } + req, err := http.NewRequestWithContext(ctx, http.MethodPut, u.String(), buf) + if err != nil { + return nil, fmt.Errorf("make request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-Xbl-Contract-Version", strconv.Itoa(contractVersion)) + + resp, err := conf.Client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + switch resp.StatusCode { + case http.StatusOK, http.StatusCreated: + var commitment *Commit + if err := json.NewDecoder(resp.Body).Decode(&commitment); err != nil { + return nil, fmt.Errorf("decode response body: %w", err) + } + return commitment, nil + case http.StatusNoContent: + return nil, nil + default: + return nil, fmt.Errorf("%s %s: %s", req.Method, req.URL, resp.Status) + } +} + +// A SessionReference contains a reference to a Session. +type SessionReference struct { + ServiceConfigID uuid.UUID `json:"scid,omitempty"` + TemplateName string `json:"templateName,omitempty"` + Name string `json:"name,omitempty"` +} + +// URL returns the [url.URL] of the session referenced in SessionReference. +func (ref SessionReference) URL() *url.URL { + return &url.URL{ + Scheme: "https", + Host: "sessiondirectory.xboxlive.com", + Path: path.Join( + "/serviceconfigs/", ref.ServiceConfigID.String(), + "/sessionTemplates/", ref.TemplateName, + "/sessions/", ref.Name, + ), + } +} + +// Commit includes a [SessionDescription] returned as a response body from the service. +// It can be retrieved on [Session.Query], [Query], and [Session.Commit]. +type Commit struct { + ContractVersion uint32 `json:"contractVersion,omitempty"` + CorrelationID uuid.UUID `json:"correlationId,omitempty"` + SearchHandle uuid.UUID `json:"searchHandle,omitempty"` + Branch uuid.UUID `json:"branch,omitempty"` + ChangeNumber uint64 `json:"changeNumber,omitempty"` + StartTime time.Time `json:"startTime,omitempty"` + NextTimer time.Time `json:"nextTimer,omitempty"` + + *SessionDescription +} + +const contractVersion = 107 diff --git a/mpsd/handler.go b/mpsd/handler.go new file mode 100644 index 0000000..e33d432 --- /dev/null +++ b/mpsd/handler.go @@ -0,0 +1,31 @@ +package mpsd + +import "github.com/google/uuid" + +// Handler notifies that a Session has been changed. It is called by the handler of +// *rta.Subscription contracted with *rta.Conn on [PublishConfig.PublishContext]. +type Handler interface { + // HandleSessionChange handles a change of session. The latest state of Session can be + // retrieved via [Session.Query]. + HandleSessionChange(ref SessionReference, branch uuid.UUID, changeNumber uint64) +} + +// A NopHandler implements a no-op Handler, which does nothing. +type NopHandler struct{} + +func (NopHandler) HandleSessionChange(SessionReference, uuid.UUID, uint64) {} + +// Handle stores a Handler into the Session atomically, which notifies events that may occur +// in the *rta.Subscription of the Session. If Handler is a nil, a NopHandler will be stored +// instead. +func (s *Session) Handle(h Handler) { + if h == nil { + h = NopHandler{} + } + s.h.Store(&h) +} + +// handler returns the Handler of the Session. It is usually called to handle events that may occur. +func (s *Session) handler() Handler { + return *s.h.Load() +} diff --git a/mpsd/invite.go b/mpsd/invite.go new file mode 100644 index 0000000..6886bef --- /dev/null +++ b/mpsd/invite.go @@ -0,0 +1,68 @@ +package mpsd + +import ( + "bytes" + "encoding/json" + "fmt" + "github.com/google/uuid" + "net/http" + "strconv" + "time" +) + +// Invite sends an invitation into the user referenced by XUID. The ID of the title which has sent +// an invitation is required to call this method. An InviteHandle may be returned. +func (s *Session) Invite(xuid string, titleID int32) (*InviteHandle, error) { + buf := &bytes.Buffer{} + if err := json.NewEncoder(buf).Encode(&inviteHandle{ + Type: "invite", + SessionReference: s.ref, + Version: 1, + InvitedXUID: xuid, + InviteAttributes: map[string]any{ + "titleId": strconv.FormatInt(int64(titleID), 10), + }, + }); err != nil { + return nil, fmt.Errorf("encode request body: %w", err) + } + req, err := http.NewRequest(http.MethodPost, handlesURL.String(), buf) + if err != nil { + return nil, fmt.Errorf("make request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-Xbl-Contract-Version", strconv.Itoa(contractVersion)) + + resp, err := s.conf.Client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + switch resp.StatusCode { + case http.StatusCreated: + // It seems the C++ implementation only decodes "id" field from the response. + var handle *InviteHandle + if err := json.NewDecoder(resp.Body).Decode(&handle); err != nil { + return nil, fmt.Errorf("decode response body: %w", err) + } + return handle, nil + default: + return nil, fmt.Errorf("%s %s: %s", req.Method, req.URL, resp.Status) + } +} + +type inviteHandle struct { + Type string `json:"type,omitempty"` // Always "invite". + Version int `json:"version,omitempty"` // Always 1. + InviteAttributes map[string]any `json:"inviteAttributes,omitempty"` + InvitedXUID string `json:"invitedXuid,omitempty"` + SessionReference SessionReference `json:"sessionRef,omitempty"` +} + +type InviteHandle struct { + inviteHandle + Expiration time.Time `json:"expiration,omitempty"` + ID uuid.UUID `json:"id,omitempty"` + InviteProtocol string `json:"inviteProtocol,omitempty"` + SenderXUID string `json:"senderXuid,omitempty"` + GameTypes json.RawMessage `json:"gameTypes,omitempty"` +} diff --git a/mpsd/join.go b/mpsd/join.go new file mode 100644 index 0000000..239db81 --- /dev/null +++ b/mpsd/join.go @@ -0,0 +1,23 @@ +package mpsd + +import ( + "context" + "github.com/df-mc/go-xsapi" + "github.com/google/uuid" +) + +// JoinConfig implements methods for joining a Session from several handles. It also includes +// a PublishConfig to publish a SessionDescription into the URL referenced in a handle. +type JoinConfig struct { + PublishConfig +} + +// JoinHandleContext joins a Session from the ID of handle and a reference to it. +func (conf JoinConfig) JoinHandleContext(ctx context.Context, src xsapi.TokenSource, handleID uuid.UUID, ref SessionReference) (*Session, error) { + return conf.publish(ctx, src, handlesURL.JoinPath(handleID.String(), "session"), ref) +} + +// JoinActivityContext joins a Session from ActivityHandle. +func (conf JoinConfig) JoinActivityContext(ctx context.Context, src xsapi.TokenSource, handle ActivityHandle) (*Session, error) { + return conf.JoinHandleContext(ctx, src, handle.ID, handle.SessionReference) +} diff --git a/mpsd/member.go b/mpsd/member.go new file mode 100644 index 0000000..2225486 --- /dev/null +++ b/mpsd/member.go @@ -0,0 +1,61 @@ +package mpsd + +import ( + "encoding/json" + "github.com/google/uuid" +) + +// MemberDescription represents a read only reference to member in a multiplayer session. +type MemberDescription struct { + Constants *MemberConstants `json:"constants,omitempty"` + Properties *MemberProperties `json:"properties,omitempty"` + Roles json.RawMessage `json:"roles,omitempty"` +} + +type MemberProperties struct { + System *MemberPropertiesSystem `json:"system,omitempty"` + Custom json.RawMessage `json:"custom,omitempty"` +} + +type MemberPropertiesSystem struct { + Active bool `json:"active,omitempty"` + Ready bool `json:"ready,omitempty"` + Connection uuid.UUID `json:"connection,omitempty"` + Subscription *MemberPropertiesSystemSubscription `json:"subscription,omitempty"` + SecureDeviceAddress []byte `json:"secureDeviceAddress,omitempty"` + InitializationGroup []uint32 `json:"initializationGroup,omitempty"` + Groups []string `json:"groups,omitempty"` + Encounters []string `json:"encounters,omitempty"` + Measurements json.RawMessage `json:"measurements,omitempty"` + ServerMeasurements json.RawMessage `json:"serverMeasurements,omitempty"` +} + +type MemberPropertiesSystemSubscription struct { + ID string `json:"id,omitempty"` + // ChangeTypes defines values that indicate change types for a multiplayer session. + ChangeTypes []string `json:"changeTypes,omitempty"` +} + +const ( + ChangeTypeEverything = "everything" + ChangeTypeHost = "host" + ChangeTypeInitialization = "initialization" + ChangeTypeMatchmakingStatus = "matchmakingStatus" + ChangeTypeMembersList = "membersList" + ChangeTypeMembersStatus = "membersStatus" + ChangeTypeJoinability = "joinability" + ChangeTypeCustomProperty = "customProperty" + ChangeTypeMembersCustomProperty = "membersCustomProperty" +) + +type MemberConstants struct { + System *MemberConstantsSystem `json:"system,omitempty"` + // Custom is a JSON string that specify the custom constants for the member. + Custom json.RawMessage `json:"custom,omitempty"` +} + +type MemberConstantsSystem struct { + // XUID is the user ID of the member. Only known if the member has accepted. + XUID string `json:"xuid,omitempty"` + Initialize bool `json:"initialize,omitempty"` +} diff --git a/mpsd/publish.go b/mpsd/publish.go new file mode 100644 index 0000000..461a287 --- /dev/null +++ b/mpsd/publish.go @@ -0,0 +1,126 @@ +package mpsd + +import ( + "context" + "encoding/json" + "fmt" + "github.com/df-mc/go-xsapi" + "github.com/df-mc/go-xsapi/internal" + "github.com/df-mc/go-xsapi/rta" + "github.com/google/uuid" + "log/slog" + "net/http" + "net/url" + "strings" +) + +// PublishConfig contains an options for publishing a SessionDescription into a Session. +type PublishConfig struct { + RTADialer *rta.Dialer + RTAConn *rta.Conn + + Description *SessionDescription + + Client *http.Client + Logger *slog.Logger +} + +func (conf PublishConfig) publish(ctx context.Context, src xsapi.TokenSource, u *url.URL, ref SessionReference) (*Session, error) { + if conf.Logger == nil { + conf.Logger = slog.Default() + } + if conf.Client == nil { + conf.Client = &http.Client{} + } + internal.SetTransport(conf.Client, src) + + if conf.RTAConn == nil { + if conf.RTADialer == nil { + conf.RTADialer = &rta.Dialer{} + } + var err error + conf.RTAConn, err = conf.RTADialer.DialContext(ctx, src) + if err != nil { + return nil, fmt.Errorf("prepare subscription: dial: %w", err) + } + } + + sub, err := conf.RTAConn.Subscribe(ctx, resourceURI) + if err != nil { + return nil, fmt.Errorf("prepare subscription: subscribe: %w", err) + } + var custom subscription + if err := json.Unmarshal(sub.Custom, &custom); err != nil { + return nil, fmt.Errorf("prepare subscription: decode: %w", err) + } + + if conf.Description == nil { + conf.Description = &SessionDescription{} + } + if conf.Description.Members == nil { + conf.Description.Members = make(map[string]*MemberDescription, 1) + } + + if ref.Name == "" { + ref.Name = strings.ToUpper(uuid.NewString()) + } + + me, ok := conf.Description.Members["me"] + if !ok { + me = &MemberDescription{} + } + if me.Constants == nil { + me.Constants = &MemberConstants{} + } + if me.Constants.System == nil { + me.Constants.System = &MemberConstantsSystem{} + } + me.Constants.System.Initialize = true + if me.Constants.System.XUID == "" { + tok, err := src.Token() + if err != nil { + return nil, fmt.Errorf("obtain token: %w", err) + } + me.Constants.System.XUID = tok.DisplayClaims().XUID + } + if me.Properties == nil { + me.Properties = &MemberProperties{} + } + if me.Properties.System == nil { + me.Properties.System = &MemberPropertiesSystem{} + } + me.Properties.System.Active = true + me.Properties.System.Connection = custom.ConnectionID + if me.Properties.System.Subscription == nil { + me.Properties.System.Subscription = &MemberPropertiesSystemSubscription{} + } + if me.Properties.System.Subscription.ID == "" { + me.Properties.System.Subscription.ID = strings.ToUpper(uuid.NewString()) + } + me.Properties.System.Subscription.ChangeTypes = []string{ + ChangeTypeEverything, + } + conf.Description.Members["me"] = me + + if _, err := conf.commit(ctx, u, conf.Description); err != nil { + return nil, fmt.Errorf("commit: %w", err) + } + if err := conf.commitActivity(ctx, ref); err != nil { + return nil, fmt.Errorf("commit activity: %w", err) + } + + s := &Session{ + ref: ref, + conf: conf, + rta: conf.RTAConn, + sub: sub, + } + s.Handle(nil) + sub.Handle(&subscriptionHandler{s}) + return s, nil +} + +// PublishContext publishes a Session on the SessionReference using the [context.Context]. +func (conf PublishConfig) PublishContext(ctx context.Context, src xsapi.TokenSource, ref SessionReference) (s *Session, err error) { + return conf.publish(ctx, src, ref.URL(), ref) +} diff --git a/mpsd/query.go b/mpsd/query.go new file mode 100644 index 0000000..8f01188 --- /dev/null +++ b/mpsd/query.go @@ -0,0 +1,45 @@ +package mpsd + +import ( + "encoding/json" + "fmt" + "github.com/df-mc/go-xsapi" + "github.com/df-mc/go-xsapi/internal" + "net/http" + "strconv" +) + +type Query struct { + Client *http.Client +} + +// Query retrieves the Commit of a session referenced in SessionReference. +func (q Query) Query(src xsapi.TokenSource, ref SessionReference) (*Commit, error) { + if q.Client == nil { + q.Client = &http.Client{} + } + internal.SetTransport(q.Client, src) + + req, err := http.NewRequest(http.MethodGet, ref.URL().String(), nil) + if err != nil { + return nil, fmt.Errorf("make request: %w", err) + } + req.Header.Set("X-Xbl-Contract-Version", strconv.Itoa(contractVersion)) + + resp, err := q.Client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + switch resp.StatusCode { + case http.StatusOK: + var c *Commit + if err := json.NewDecoder(resp.Body).Decode(&c); err != nil { + return nil, fmt.Errorf("decode response body: %w", err) + } + return c, nil + default: + return nil, fmt.Errorf("%s %s: %s", req.Method, req.URL, resp.Status) + } +} diff --git a/mpsd/session.go b/mpsd/session.go new file mode 100644 index 0000000..8846e5b --- /dev/null +++ b/mpsd/session.go @@ -0,0 +1,235 @@ +package mpsd + +import ( + "context" + "encoding/json" + "github.com/df-mc/go-xsapi/rta" + "sync/atomic" +) + +type Session struct { + ref SessionReference + conf PublishConfig + + rta *rta.Conn + + sub *rta.Subscription + + h atomic.Pointer[Handler] +} + +func (s *Session) Query() (*Commit, error) { + q := Query{Client: s.conf.Client} + return q.Query(nil, s.ref) +} + +func (s *Session) Close() error { + if err := s.rta.Unsubscribe(context.Background(), s.sub); err != nil { + s.conf.Logger.Error("error unsubscribing with RTA", "err", err) + } + _, err := s.Commit(context.Background(), &SessionDescription{ + Members: map[string]*MemberDescription{ + "me": nil, + }, + }) + return err +} + +type SessionDescription struct { + Constants *SessionConstants `json:"constants,omitempty"` + RoleTypes json.RawMessage `json:"roleTypes,omitempty"` + Properties *SessionProperties `json:"properties,omitempty"` + Members map[string]*MemberDescription `json:"members,omitempty"` +} + +// SessionProperties is a set of properties associated with multiplayer session. +// Any member can modify these fields. +type SessionProperties struct { + System *SessionPropertiesSystem `json:"system,omitempty"` + // Custom is a JSON string that specify the custom properties for the session. These can + // be changed anytime. + Custom json.RawMessage `json:"custom,omitempty"` +} + +type SessionPropertiesSystem struct { + // Keywords is an optional list of keywords associated with the session. + Keywords []string `json:"keywords,omitempty"` + // Turn is a list of member IDs indicating whose turn it is. + Turn []uint32 `json:"turn,omitempty"` + // JoinRestriction restricts who can join "open" sessions. (Has no effects on reservations, + // which means it has no impact on "private" and "visible" sessions) + // It is one of constants defined below. + JoinRestriction string `json:"joinRestriction,omitempty"` + // ReadRestriction restricts who can read "open" sessions. (Has no effect on reservations, + // which means it has no impact on "private" and "visible" sessions.) + ReadRestriction string `json:"readRestriction,omitempty"` + // Controls whether a session is joinable, independent of visibility, join restriction, + // and available space in the session. Does not affect reservations. Defaults to false. + Closed bool `json:"closed"` + // If Locked is true, it would allow the members of the session to be locked, such that + // if a user leaves they are able to come back into the session but no other user could + // take that spot. Defaults to false. + Locked bool `json:"locked,omitempty"` + Matchmaking *SessionPropertiesSystemMatchmaking `json:"matchmaking,omitempty"` + // MatchmakingResubmit is true, if the match that was found didn't work out and needs to + // be resubmitted. If false, signal that the match did work, and the matchmaking service + // can release the session. + MatchmakingResubmit bool `json:"matchmakingResubmit,omitempty"` + // InitializationSucceeded is true if initialization succeeded. + InitializationSucceeded bool `json:"initializationSucceeded,omitempty"` + // Host is the device token of the host. + Host string `json:"host,omitempty"` + // ServerConnectionStringCandidates is the ordered list of case-insensitive connection + // strings that the session could use to connect to a game server. Generally titles + // should use the first on the list, but sophisticated titles could use a custom mechanism + // for choosing one of the others (e.g. based on load). + ServerConnectionStringCandidates json.RawMessage `json:"serverConnectionStringCandidates,omitempty"` +} + +type SessionPropertiesSystemMatchmaking struct { + // TargetSessionConstants is a JSON string representing the target session constants. + TargetSessionConstants json.RawMessage `json:"targetSessionConstants,omitempty"` + // ServerConnectionString Force a specific connection string to be used. This is useful + // for session in progress join scenarios. + ServerConnectionString string `json:"serverConnectionString,omitempty"` +} + +const ( + SessionRestrictionNone = "none" + SessionRestrictionLocal = "local" + SessionRestrictionFollowed = "followed" +) + +// SessionConstants represents constants for a multiplayer session. +// +// SessionConstants are set by the creator or by the session template only when a +// session is created. Fields in SessionConstants generally cannot be changed after +// the session is created. +type SessionConstants struct { + System *SessionConstantsSystem `json:"system,omitempty"` + // Custom is any custom constants for the session, specified in a JSON string. + Custom json.RawMessage `json:"custom,omitempty"` +} + +type SessionConstantsSystem struct { + // MaxMembersCount is the maximum number of members in the session. + MaxMembersCount uint32 `json:"maxMembersCount,omitempty"` + // Capabilities is the capabilities of the session. + Capabilities *SessionCapabilities `json:"capabilities,omitempty"` + // Visibility is the visibility of the session. + Visibility string `json:"visibility,omitempty"` + // Initiators is a list of XUIDs indicating who initiated the session. + Initiators []string `json:"initiators,omitempty"` + // ReservedRemovalTimeout is the maximum time, in milliseconds, for a member with a reservation + // to join the session. If the member doesn't join within this time, this reservation is removed. + ReservedRemovalTimeout uint64 `json:"reservedRemovalTimeout,omitempty"` + // InactiveRemovalTimeout is the maximum time, in milliseconds, for an inactive member to become + // active. If an inactive member doesn't become active within this time, the member is removed from + // the session. + InactiveRemovalTimeout uint64 `json:"inactiveRemovalTimeout,omitempty"` + // ReadyRemovalTimeout is the maximum time, in milliseconds, for a member who is marked as ready + // to become active. When the shell launches the title to start a multiplayer game, the member is + // marked as ready. If a member who is marked as ready doesn't become active with in this time, + // the member becomes inactive. + ReadyRemovalTimeout uint64 `json:"readyRemovalTimeout,omitempty"` + // SessionEmptyTimeout is the maximum time, in milliseconds, that the session can remain empty. + // If no members join the session within this time, the session is deleted. + SessionEmptyTimeout uint64 `json:"sessionEmptyTimeout,omitempty"` + Metrics *SessionConstantsSystemMetrics `json:"metrics,omitempty"` + // If MemberInitialization is set, the session expects the client system or title to perform initialization + // after session creation. Timeouts and initialization stages are automatically tracked by the session, including + // initial Quality of Service (QoS) measurements if any metrics are set. + MemberInitialization *MemberInitialization `json:"memberInitialization,omitempty"` + // PeerToPeerRequirements is a QoS requirements for a connection between session members. + PeerToPeerRequirements *PeerToPeerRequirements `json:"peerToPeerRequirements,omitempty"` + // PeerToHostRequirements is a QoS requirements for a connection between a host candidate + // and session members. + PeerToHostRequirements *PeerToHostRequirements `json:"peerToHostRequirements,omitempty"` + // MeasurementServerAddresses is the set of potential server connection strings that should + // be evaluated. + MeasurementServerAddresses json.RawMessage `json:"measurementServerAddresses,omitempty"` + // CloudComputePackage is the Cloud Compute package constants for the session, specified in a JSON string. + CloudComputePackage json.RawMessage `json:"cloudComputePackage,omitempty"` +} + +type PeerToHostRequirements struct { + LatencyMaximum uint64 `json:"latencyMaximum,omitempty"` + BandwidthDownMinimum uint64 `json:"bandwidthDownMinimum,omitempty"` + BandwidthUpMinimum uint64 `json:"bandwidthUpMinimum,omitempty"` + HostSelectionMetric string `json:"hostSelectionMetric,omitempty"` +} + +const ( + HostSelectionMetricBandwidthUp = "bandwidthUp" + HostSelectionMetricBandwidthDown = "bandwidthDown" + HostSelectionMetricBandwidth = "bandwidth" + HostSelectionMetricLatency = "latency" +) + +type PeerToPeerRequirements struct { + LatencyMaximum uint64 `json:"latencyMaximum,omitempty"` + BandwidthMinimum uint64 `json:"bandwidthMinimum,omitempty"` +} + +type MemberInitialization struct { + JoinTimeout uint64 `json:"joinTimeout,omitempty"` + MeasurementTimeout uint64 `json:"measurementTimeout,omitempty"` + EvaluationTimeout uint64 `json:"evaluationTimeout,omitempty"` + ExternalEvaluation bool `json:"externalEvaluation,omitempty"` + MembersNeededToStart uint32 `json:"membersNeededToStart,omitempty"` +} + +type SessionConstantsSystemMetrics struct { + // Latency indicates that the title wants latency measured to + // help determine connectivity. + Latency bool `json:"latency,omitempty"` + // Bandwidth indicates that the title wants downstream (host-to-peer) + // bandwidth measured to help determine connectivity. + BandwidthDown bool `json:"bandwidthDown,omitempty"` + // BandwidthUp indicates that the title wants upstream (peer-to-host) + // bandwidth measured to help determine connectivity. + BandwidthUp bool `json:"bandwidthUp,omitempty"` + // Custom indicates that the title wants a custom measurement to help + // determine connectivity. + Custom bool `json:"custom,omitempty"` +} + +// SessionCapabilities represents the capabilities of multiplayer session. +// +// SessionCapabilities are optional bool values that are set in the session +// template. If no capabilities are needed, an empty SessionCapabilities should +// be used in the template to prevent capabilities from being specified at session +// creation, unless the title requires dynamic session capabilities. +type SessionCapabilities struct { + // Connectivity indicates whether a session can enable metrics. + Connectivity bool `json:"connectivity,omitempty"` + // If SuppressPresenceActivityCheck is false (the default value), active users are required to + // remain online playing the title. If they don't, they are demoted to inactive status. Set + // SuppressPresenceActivityCheck to true to enable session members to stay active indefinitely + SuppressPresenceActivityCheck bool `json:"suppressPresenceActivityCheck,omitempty"` + // Gameplay indicates whether the session represents actual gameplay rather than time in setup + // or a menu, such as a lobby or during matchmaking. + Gameplay bool `json:"gameplay,omitempty"` + // If Large is true, if the session can host 101 to 1000 users, which affects other session features. + // Otherwise, the session can host 1 to 100 users. + Large bool `json:"large,omitempty"` + // If UserAuthorizationStyle is true, the session supports calls from platforms without strong + // title identity. This capability can't be set on large sessions. + UserAuthorizationStyle bool `json:"userAuthorizationStyle,omitempty"` + // If ConnectionRequiredForActiveMembers is true, a connection is required for a member to be + // marked as active. To enable session notifications and detect disconnections, it must be set + // to true. + ConnectionRequiredForActiveMembers bool `json:"connectionRequiredForActiveMembers,omitempty"` + // CrossPlay is true, if the session supports crossplay. + CrossPlay bool `json:"crossPlay,omitempty"` + // If Searchable is true, the session can be linked to a search handle for searching. + Searchable bool `json:"searchable,omitempty"` + // If HasOwners is true, the session has owners. + HasOwners bool `json:"hasOwners,omitempty"` +} + +const ( + SessionVisibilityPrivate = "private" + SessionVisibilityVisible = "visible" + SessionVisibilityOpen = "open" +) diff --git a/mpsd/subscription.go b/mpsd/subscription.go new file mode 100644 index 0000000..966e39c --- /dev/null +++ b/mpsd/subscription.go @@ -0,0 +1,58 @@ +package mpsd + +import ( + "encoding/json" + "fmt" + "github.com/df-mc/go-xsapi/internal" + "github.com/google/uuid" + "strings" +) + +const resourceURI = "https://sessiondirectory.xboxlive.com/connections/" + +type subscription struct { + ConnectionID uuid.UUID `json:"ConnectionId,omitempty"` +} + +type subscriptionHandler struct { + *Session +} + +func (h *subscriptionHandler) HandleEvent(data json.RawMessage) { + var event subscriptionEvent + if err := json.Unmarshal(data, &event); err != nil { + h.conf.Logger.Error("error decoding subscription event", internal.ErrAttr(err)) + } + for _, tap := range event.ShoulderTaps { + ref, err := h.parseReference(tap.Resource) + if err != nil { + h.conf.Logger.Error("handle subscription event: error parsing shoulder tap", internal.ErrAttr(err)) + continue + } + h.handler().HandleSessionChange(ref, tap.Branch, tap.ChangeNumber) + } +} + +func (h *subscriptionHandler) parseReference(s string) (ref SessionReference, err error) { + segments := strings.Split(s, "~") + if len(segments) != 3 { + return ref, fmt.Errorf("unexpected segmentations: %s", s) + } + ref.ServiceConfigID, err = uuid.Parse(segments[0]) + if err != nil { + return ref, fmt.Errorf("parse service config ID: %w", err) + } + ref.TemplateName = segments[1] + ref.Name = segments[2] + return ref, nil +} + +type subscriptionEvent struct { + ShoulderTaps []shoulderTap `json:"shoulderTaps"` +} + +type shoulderTap struct { + Resource string `json:"resource"` + ChangeNumber uint64 `json:"changeNumber"` + Branch uuid.UUID `json:"branch"` +} diff --git a/rta/conn.go b/rta/conn.go new file mode 100644 index 0000000..7194851 --- /dev/null +++ b/rta/conn.go @@ -0,0 +1,263 @@ +package rta + +import ( + "context" + "encoding/json" + "fmt" + "github.com/coder/websocket" + "github.com/coder/websocket/wsjson" + "github.com/df-mc/go-xsapi/internal" + "log/slog" + "net" + "sync" + "sync/atomic" +) + +// Conn represents a connection between the real-time activity services. It can +// be established from Dialer with an authorization token that relies on the +// party 'https://xboxlive.com/'. +// +// A Conn controls subscriptions real-timely under a websocket connection. An +// index-specific JSON array is used for the communication. Conn is safe for +// concurrent use in multiple goroutines. +// +// SubscriptionHandlers are useful to handle any events that may occur in the subscriptions +// controlled by Conn, and can be stored atomically to a Subscription from [Subscription.Handle]. +type Conn struct { + conn *websocket.Conn + + sequences [operationCapacity]atomic.Uint32 + expected [operationCapacity]map[uint32]chan<- *handshake + expectedMu sync.RWMutex + + subscriptions map[uint32]*Subscription + subscriptionsMu sync.RWMutex + + log *slog.Logger + + once sync.Once + closed chan struct{} +} + +// Subscribe attempts to subscribe with the specific resource URI, with the [context.Context] +// to be used during the handshake. A Subscription may be returned, which contains an ID +// and Custom data as the result of handshake. +func (c *Conn) Subscribe(ctx context.Context, resourceURI string) (*Subscription, error) { + sequence := c.sequences[operationSubscribe].Add(1) + hand, err := c.shake(operationSubscribe, sequence, []any{resourceURI}) + if err != nil { + return nil, err + } + defer c.release(operationSubscribe, sequence) + select { + case h := <-hand: + switch h.status { + case StatusOK: + if len(h.payload) < 2 { + return nil, &OutOfRangeError{ + Payload: h.payload, + Index: 1, + } + } + sub := &Subscription{} + if err := json.Unmarshal(h.payload[0], &sub.ID); err != nil { + return nil, fmt.Errorf("decode subscription ConnectionID: %w", err) + } + sub.Custom = h.payload[1] + + c.subscriptionsMu.Lock() + c.subscriptions[sub.ID] = sub + c.subscriptionsMu.Unlock() + return sub, nil + default: + return nil, unexpectedStatusCode(h.status, h.payload) + } + case <-ctx.Done(): + return nil, ctx.Err() + case <-c.closed: + return nil, net.ErrClosed + } +} + +// Unsubscribe attempts to unsubscribe with a Subscription associated with an ID, with +// the [context.Context] to be used during the handshake. An error may be returned. +func (c *Conn) Unsubscribe(ctx context.Context, sub *Subscription) error { + sequence := c.sequences[operationUnsubscribe].Add(1) + hand, err := c.shake(operationUnsubscribe, sequence, []any{sub.ID}) + if err != nil { + return err + } + defer c.release(operationUnsubscribe, sequence) + select { + case h := <-hand: + if h.status != StatusOK { + return unexpectedStatusCode(h.status, h.payload) + } + return nil + case <-ctx.Done(): + return ctx.Err() + case <-c.closed: + return net.ErrClosed + } +} + +// Subscription represents a subscription contracted with the resource URI available through +// the real-time activity service. A Subscription may be contracted via Conn.Subscribe. +type Subscription struct { + ID uint32 + Custom json.RawMessage + + h SubscriptionHandler + mu sync.Mutex +} + +func (s *Subscription) Handle(h SubscriptionHandler) { + s.mu.Lock() + s.h = h + s.mu.Unlock() +} + +func (s *Subscription) handler() SubscriptionHandler { + s.mu.Lock() + defer s.mu.Unlock() + if s.h == nil { + return NopSubscriptionHandler{} + } + return s.h +} + +type SubscriptionHandler interface { + HandleEvent(custom json.RawMessage) +} + +type NopSubscriptionHandler struct{} + +func (NopSubscriptionHandler) HandleEvent(json.RawMessage) {} + +// write attempts to write a JSON array with header and the body. A background context is +// used as no context perceived by the parent goroutine should be used to a websocket method +// to avoid closing the connection if it has cancelled or exceeded a deadline. +func (c *Conn) write(typ uint32, payload []any) error { + return wsjson.Write(context.Background(), c.conn, append([]any{typ}, payload...)) +} + +// read goes as a background goroutine of Conn, reading a JSON array from the websocket +// connection and decoding a header needed to indicate which message should be handled. +func (c *Conn) read() { + for { + var payload []json.RawMessage + if err := wsjson.Read(context.Background(), c.conn, &payload); err != nil { + _ = c.Close() + return + } + typ, err := readHeader(payload) + if err != nil { + c.log.Error("error reading header", internal.ErrAttr(err)) + continue + } + go c.handleMessage(typ, payload[1:]) + } +} + +// Close closes the websocket connection with websocket.StatusNormalClosure. +func (c *Conn) Close() (err error) { + c.once.Do(func() { + close(c.closed) + err = c.conn.Close(websocket.StatusNormalClosure, "") + }) + return err +} + +// handleMessage handles a message received in read with the type. +func (c *Conn) handleMessage(typ uint32, payload []json.RawMessage) { + switch typ { + case typeSubscribe, typeUnsubscribe: // Subscribe & Unsubscribe handshake response + h, err := readHandshake(payload) + if err != nil { + c.log.Error("error reading handshake response", internal.ErrAttr(err)) + return + } + op := typeToOperation(typ) + c.expectedMu.RLock() + defer c.expectedMu.RUnlock() + hand, ok := c.expected[op][h.sequence] + if !ok { + c.log.Debug("unexpected handshake response", slog.Group("message", "type", typ, "sequence", h.sequence)) + return + } + hand <- h + case typeEvent: + if len(payload) < 2 { + c.log.Debug("event message has no custom") + return + } + var subscriptionID uint32 + if err := json.Unmarshal(payload[0], &subscriptionID); err != nil { + c.log.Error("error decoding subscription ID", internal.ErrAttr(err)) + } + c.subscriptionsMu.Lock() + defer c.subscriptionsMu.Unlock() + sub, ok := c.subscriptions[subscriptionID] + if ok { + go sub.handler().HandleEvent(payload[1]) + } + c.log.Debug("received event", slog.Group("message", "type", typ, "custom", payload[0])) + default: + c.log.Debug("received an unexpected message", slog.Group("message", "type", typ)) + } +} + +// An OutOfRangeError occurs when reading values from payload received from the service. +// The Payload specifies the remaining values included in the payload, and the Index specifies +// a length of values that is missing from the payload. +type OutOfRangeError struct { + Payload []json.RawMessage + Index int +} + +func (e *OutOfRangeError) Error() string { + return fmt.Sprintf("xsapi/rta: index out of range [%d] with length %d", e.Index, len(e.Payload)) +} + +// readHeader decodes a header from the first 1 value from the payload. An OutOfRangeError +// may be returned if the payload has not enough length to read. +func readHeader(payload []json.RawMessage) (typ uint32, err error) { + if len(payload) < 1 { + return typ, &OutOfRangeError{ + Payload: payload, + Index: 0, + } + } + return typ, json.Unmarshal(payload[0], &typ) +} + +// readHandshake decodes a handshake from the first 2 values from the payload. +// An OutOfRangeError may be returned if the payload has not enough length to read. +func readHandshake(payload []json.RawMessage) (*handshake, error) { + if len(payload) < 2 { + return nil, &OutOfRangeError{ + Payload: payload, + Index: 2, + } + } + h := &handshake{} + if err := json.Unmarshal(payload[0], &h.sequence); err != nil { + return nil, fmt.Errorf("decode sequence: %w", err) + } + if err := json.Unmarshal(payload[1], &h.status); err != nil { + return nil, fmt.Errorf("decode status code: %w", err) + } + h.payload = payload[2:] + return h, nil +} + +// unexpectedStatusCode wraps an UnexpectedStatusError from the status. +// If the payload has more than one remaining values, it will try to decode +// them as an error message. +func unexpectedStatusCode(status int32, payload []json.RawMessage) error { + err := &UnexpectedStatusError{Code: status} + if len(payload) >= 1 { + _ = json.Unmarshal(payload[0], &err.Message) + } + return err +} diff --git a/rta/dial.go b/rta/dial.go new file mode 100644 index 0000000..3d4ccc6 --- /dev/null +++ b/rta/dial.go @@ -0,0 +1,71 @@ +package rta + +import ( + "context" + "github.com/coder/websocket" + "github.com/df-mc/go-xsapi" + "github.com/df-mc/go-xsapi/internal" + "log/slog" + "net/http" + "slices" + "time" +) + +// Dialer represents the options for establishing a Conn with real-time activity services with DialContext or Dial. +type Dialer struct { + Options *websocket.DialOptions + ErrorLog *slog.Logger +} + +// Dial calls DialContext with a 15 seconds timeout. +func (d Dialer) Dial(src xsapi.TokenSource) (*Conn, error) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*15) + defer cancel() + return d.DialContext(ctx, src) +} + +// DialContext establishes a connection with real-time activity service. A context.Context is used to control the +// scene real-timely. An authorization token may be used for configuring an HTTP header to Options. An error may be +// returned during the dial of websocket connection. +func (d Dialer) DialContext(ctx context.Context, src xsapi.TokenSource) (*Conn, error) { + if d.ErrorLog == nil { + d.ErrorLog = slog.Default() + } + if d.Options == nil { + d.Options = &websocket.DialOptions{} + } + if !slices.Contains(d.Options.Subprotocols, subprotocol) { + d.Options.Subprotocols = append(d.Options.Subprotocols, subprotocol) + } + if d.Options.HTTPHeader == nil { + d.Options.HTTPHeader = make(http.Header) + } + + if d.Options.HTTPClient == nil { + d.Options.HTTPClient = &http.Client{} + } + internal.SetTransport(d.Options.HTTPClient, src) + + c, _, err := websocket.Dial(ctx, connectURL, d.Options) + if err != nil { + return nil, err + } + conn := &Conn{ + conn: c, + log: d.ErrorLog, + subscriptions: make(map[uint32]*Subscription), + } + for i := 0; i < cap(conn.expected); i++ { + conn.expected[i] = make(map[uint32]chan<- *handshake) + } + go conn.read() + return conn, nil +} + +const ( + // connectURL is the URL used to establish a websocket connection with real-time activity services. It is + // generally present at websocket.Dial with other websocket.DialOptions, specifically along with subprotocol. + connectURL = "wss://rta.xboxlive.com/connect" + // subprotocol is the subprotocol used with connectURL, to establish a websocket connection. + subprotocol = "rta.xboxlive.com.V2" +) diff --git a/rta/handshake.go b/rta/handshake.go new file mode 100644 index 0000000..b54a9da --- /dev/null +++ b/rta/handshake.go @@ -0,0 +1,91 @@ +package rta + +import ( + "encoding/json" + "strconv" + "strings" +) + +type handshake struct { + sequence uint32 + status int32 + payload []json.RawMessage +} + +const ( + typeSubscribe uint32 = iota + 1 + typeUnsubscribe + typeEvent + typeResync +) + +const ( + operationSubscribe uint8 = iota + operationUnsubscribe + operationCapacity // The capacity of expected handshake uses. +) + +func typeToOperation(typ uint32) uint8 { + switch typ { + case typeSubscribe: + return operationSubscribe + case typeUnsubscribe: + return operationUnsubscribe + default: + panic("unreachable") + } +} + +func operationToType(op uint8) uint32 { + switch op { + case operationSubscribe: + return typeSubscribe + case operationUnsubscribe: + return typeUnsubscribe + default: + panic("unreachable") + } +} + +func (c *Conn) shake(op uint8, sequence uint32, payload []any) (<-chan *handshake, error) { + if err := c.write(operationToType(op), append([]any{sequence}, payload...)); err != nil { + return nil, err + } + hand := make(chan *handshake) + c.expectedMu.Lock() + c.expected[op][sequence] = hand + c.expectedMu.Unlock() + return hand, nil +} + +func (c *Conn) release(op uint8, sequence uint32) { + c.expectedMu.Lock() + delete(c.expected[op], sequence) + c.expectedMu.Unlock() +} + +type UnexpectedStatusError struct { + Code int32 + Message string +} + +func (e *UnexpectedStatusError) Error() string { + b := &strings.Builder{} + b.WriteString("rta: code ") + b.WriteString(strconv.FormatInt(int64(e.Code), 10)) + if e.Message != "" { + b.WriteByte(':') + b.WriteByte(' ') + b.WriteString(e.Message) + } + return b.String() +} + +const ( + StatusOK int32 = iota + StatusUnknownResource + StatusSubscriptionLimitReached + StatusNoResourceData + StatusThrottled = 1001 + StatusServiceUnavailable = 1002 +) diff --git a/token.go b/token.go new file mode 100644 index 0000000..a3731b5 --- /dev/null +++ b/token.go @@ -0,0 +1,28 @@ +package xsapi + +import ( + "net/http" +) + +type Token interface { + // SetAuthHeader sets an 'Authorization' and a 'Signature' header in the request. + SetAuthHeader(req *http.Request) + // String formats the Token into a string that can be set as an 'Authorization' header + // or a field in requests. It usually follows the format 'XBL3.0 x=;'. + String() string + // DisplayClaims returns the DisplayClaims, which contains an information for a user. + // It is usually claimed from the response body returned from the authorization. + DisplayClaims() DisplayClaims +} + +// TokenSource implements a Token method that returns a Token. +type TokenSource interface { + Token() (Token, error) +} + +// DisplayClaims contains an information for user of Token. +type DisplayClaims struct { + GamerTag string `json:"gtg"` + XUID string `json:"xid"` + UserHash string `json:"uhs"` +} diff --git a/transport.go b/transport.go new file mode 100644 index 0000000..99a116c --- /dev/null +++ b/transport.go @@ -0,0 +1,63 @@ +package xsapi + +import ( + "errors" + "net/http" +) + +// Transport is an http.RoundTripper that makes authenticated Xbox Live requests, +// wrapping a base RoundTripper and adding an 'Authorization' header and a 'Signature' +// header with a token from the supplied Sources. +type Transport struct { + Source TokenSource + + Base http.RoundTripper +} + +// RoundTrip authorizes and authenticates the request using the +// [Token.SetAuthHeader] from Source of the Transport. +func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { + reqBodyClosed := false + if req.Body != nil { + defer func() { + if !reqBodyClosed { + req.Body.Close() + } + }() + } + + if t.Source == nil { + return nil, errors.New("xsapi: Transport's Source is nil") + } + token, err := t.Source.Token() + if err != nil { + return nil, err + } + + req2 := cloneRequest(req) + token.SetAuthHeader(req2) + + reqBodyClosed = true + return t.base().RoundTrip(req2) +} + +func (t *Transport) base() http.RoundTripper { + if t.Base != nil { + return t.Base + } + return http.DefaultTransport +} + +// cloneRequest returns a clone of the provided *http.Request. +// The clone is a shallow copy of the struct and its Header map. +func cloneRequest(r *http.Request) *http.Request { + // shallow copy of the struct + r2 := new(http.Request) + *r2 = *r + // deep copy of the Header + r2.Header = make(http.Header, len(r.Header)) + for k, s := range r.Header { + r2.Header[k] = append([]string(nil), s...) + } + return r2 +}