diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 9b3c5a286..80615c012 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -26,7 +26,7 @@ jobs: make generate - name: golangci-lint - uses: golangci/golangci-lint-action@v7 + uses: golangci/golangci-lint-action@v4 with: version: latest diff --git a/balancer/catabalancer/catalyst_balancer.go b/balancer/catabalancer/catalyst_balancer.go index 21a2d9583..98f7a8955 100644 --- a/balancer/catabalancer/catalyst_balancer.go +++ b/balancer/catabalancer/catalyst_balancer.go @@ -339,7 +339,7 @@ func (c *CataBalancer) refreshNodes(ctx context.Context) (stats, error) { if err != nil { return s, fmt.Errorf("failed to query node stats: %w", err) } - defer rows.Close() + defer rows.Close() // nolint:errcheck // Process the result set for rows.Next() { diff --git a/balancer/mist/mist_balancer.go b/balancer/mist/mist_balancer.go index 2900feb06..0d43ee26b 100644 --- a/balancer/mist/mist_balancer.go +++ b/balancer/mist/mist_balancer.go @@ -176,7 +176,7 @@ func (b *MistBalancer) changeLoadBalancerServers(ctx context.Context, server, ac glog.Errorf("Error making request: %v", err) return nil, err } - defer resp.Body.Close() + defer resp.Body.Close() // nolint:errcheck bytes, err := io.ReadAll(resp.Body) @@ -212,7 +212,7 @@ func (b *MistBalancer) getMistLoadBalancerServers(ctx context.Context) (map[stri glog.Errorf("Error making request: %v", err) return nil, err } - defer resp.Body.Close() + defer resp.Body.Close() // nolint:errcheck if resp.StatusCode != http.StatusOK { b, _ := io.ReadAll(resp.Body) @@ -473,7 +473,7 @@ func (b *MistBalancer) mistUtilLoadRequest(ctx context.Context, route, stream, l if err != nil { return "", err } - defer resp.Body.Close() + defer resp.Body.Close() // nolint:errcheck if resp.StatusCode != http.StatusOK { return "", fmt.Errorf("GET request '%s' failed with http status code %d", murl, resp.StatusCode) diff --git a/c2pa/c2pa_test.go b/c2pa/c2pa_test.go index b248a0a09..ad97825f8 100644 --- a/c2pa/c2pa_test.go +++ b/c2pa/c2pa_test.go @@ -19,7 +19,7 @@ func TestSign(t *testing.T) { outFile := "test/tiny_output_signed.mp4" c := NewC2PA("es256", "test/es256_private.key", "test/es256_certs.pem") - defer os.Remove(outFile) + defer os.Remove(outFile) // nolint:errcheck err = c.SignFile("test/tiny.mp4", outFile, "Tiny", "") @@ -38,7 +38,7 @@ func TestSignWithParent(t *testing.T) { outFile := "test/tiny_cut_output_signed.mp4" c := NewC2PA("es256", "test/es256_private.key", "test/es256_certs.pem") - defer os.Remove(outFile) + defer os.Remove(outFile) // nolint:errcheck err = c.SignFile("test/tiny_cut.mp4", outFile, "Tiny", "test/tiny_signed.mp4") diff --git a/clients/broadcaster_remote.go b/clients/broadcaster_remote.go index 09c8f3684..177c91c5a 100644 --- a/clients/broadcaster_remote.go +++ b/clients/broadcaster_remote.go @@ -84,7 +84,7 @@ func findBroadcaster(c Credentials) (BroadcasterList, error) { if err != nil { return BroadcasterList{}, fmt.Errorf("http do(%s): %v", requestURL, err) } - defer res.Body.Close() + defer res.Body.Close() // nolint:errcheck if !httpOk(res.StatusCode) { return BroadcasterList{}, fmt.Errorf("http GET(%s) returned %d %s", requestURL, res.StatusCode, res.Status) @@ -130,7 +130,7 @@ func CreateStream(c Credentials, streamName string, profiles []video.EncodedProf if err != nil { return "", fmt.Errorf("http do(%s): %v", requestURL, err) } - defer res.Body.Close() + defer res.Body.Close() // nolint:errcheck if !httpOk(res.StatusCode) { return "", fmt.Errorf("http POST(%s) returned %d %s", requestURL, res.StatusCode, res.Status) @@ -164,7 +164,7 @@ func ReleaseManifestID(c Credentials, manifestId string) error { if err != nil { return fmt.Errorf("Releasing Manifest ID failed. URL: %s, manifestID: %s, err: %s", requestURL, manifestId, err) } - defer res.Body.Close() + defer res.Body.Close() // nolint:errcheck if !httpOk(res.StatusCode) { return fmt.Errorf("Releasing Manifest ID failed. URL: %s, manifestID: %s, HTTP Code: %s", requestURL, manifestId, res.Status) diff --git a/clients/manifest.go b/clients/manifest.go index e84e76cf0..05d1d31c7 100644 --- a/clients/manifest.go +++ b/clients/manifest.go @@ -77,7 +77,7 @@ func RecordingBackupCheck(requestID string, primaryManifestURL, osTransferURL *u var rc io.ReadCloser rc, actualSegURL, err = GetFileWithBackup(context.Background(), requestID, segURL.String(), dStorage) if rc != nil { - rc.Close() + rc.Close() // nolint:errcheck } return err }, DownloadRetryBackoff()) @@ -189,7 +189,7 @@ func downloadManifest(requestID, sourceManifestOSURL string) (playlist m3u8.Play } return err } - defer rc.Close() + defer rc.Close() // nolint:errcheck data := new(bytes.Buffer) _, err = data.ReadFrom(rc) @@ -390,7 +390,7 @@ func ClipInputManifest(requestID, sourceURL, clipTargetUrl string, startTimeUnix if err != nil { return nil, fmt.Errorf("error clipping: failed to create temp clipping storage dir: %w", err) } - defer os.RemoveAll(clipStorageDir) + defer os.RemoveAll(clipStorageDir) // nolint:errcheck // Download start/end segments and clip for i, v := range segsToClip { @@ -451,7 +451,7 @@ func ClipInputManifest(requestID, sourceURL, clipTargetUrl string, startTimeUnix if err != nil { return nil, fmt.Errorf("error clipping: failed to open clipped segment %d: %w", v.SeqId, err) } - defer clippedSegmentFile.Close() + defer clippedSegmentFile.Close() // nolint:errcheck clippedSegmentOSFilename := "clip_" + strconv.FormatUint(v.SeqId, 10) + ".ts" err = UploadToOSURL(clipTargetUrl, clippedSegmentOSFilename, clippedSegmentFile, MaxCopyFileDuration) @@ -553,7 +553,7 @@ func GetFirstRenditionURL(requestID string, masterManifestURL *url.URL) (*url.UR if err != nil { return fmt.Errorf("error downloading manifest %s: %w", masterManifestURL.Redacted(), err) } - defer rc.Close() + defer rc.Close() // nolint:errcheck playlist, playlistType, err = m3u8.DecodeFrom(rc, true) if err != nil { diff --git a/clients/mist_client_test.go b/clients/mist_client_test.go index 6c1ed2c9b..0c0dbd202 100644 --- a/clients/mist_client_test.go +++ b/clients/mist_client_test.go @@ -234,6 +234,7 @@ func TestItCanParseAMistStreamStatus(t *testing.T) { mc := &MistClient{ HttpReqUrl: svr.URL, + httpClient: newRetryableClient(&http.Client{Timeout: MistClientTimeout}), } msi, err := mc.GetStreamInfo("some-stream-name") @@ -258,6 +259,7 @@ func TestItCanParseAMistStreamErrorStatus(t *testing.T) { mc := &MistClient{ HttpReqUrl: svr.URL, + httpClient: newRetryableClient(&http.Client{Timeout: MistClientTimeout}), } _, err := mc.GetStreamInfo("some-stream-name") @@ -286,6 +288,7 @@ func TestItRetriesFailingRequests(t *testing.T) { mc := &MistClient{ HttpReqUrl: svr.URL, + httpClient: newRetryableClient(&http.Client{Timeout: MistClientTimeout}), } _, err := mc.GetStreamInfo("some-stream-name") @@ -307,6 +310,7 @@ func TestItFailsWhenMaxRetriesReached(t *testing.T) { mc := &MistClient{ HttpReqUrl: svr.URL, + httpClient: newRetryableClient(&http.Client{Timeout: MistClientTimeout}), } _, err := mc.GetStreamInfo("some-stream-name") @@ -385,8 +389,9 @@ func TestItCanGetStreamStats(t *testing.T) { defer svr.Close() mc := &MistClient{ - ApiUrl: svr.URL, - cache: cache.New(200*time.Millisecond, time.Minute), + ApiUrl: svr.URL, + cache: cache.New(200*time.Millisecond, time.Minute), + httpClient: newRetryableClient(&http.Client{Timeout: MistClientTimeout}), } status, err := mc.GetState() diff --git a/config/cli.go b/config/cli.go index 89437ffaf..ed7ba4c22 100644 --- a/config/cli.go +++ b/config/cli.go @@ -96,9 +96,12 @@ type Cli struct { SerfEventBuffer int SerfMaxQueueDepth int - LBReplaceHostMatch string - LBReplaceHostPercent int - LBReplaceHostList []string + LBReplaceHostMatch string + LBReplaceHostPercent int + LBReplaceHostList []string + LBReplaceDomains map[string]string + LBReplaceDomainReferers []string + LBReplaceDomainQueryParams map[string]string } // Return our own URL for callback trigger purposes diff --git a/handlers/geolocation/geolocation.go b/handlers/geolocation/geolocation.go index 8735cd537..41050281e 100644 --- a/handlers/geolocation/geolocation.go +++ b/handlers/geolocation/geolocation.go @@ -91,7 +91,6 @@ func NewGeolocationHandlersCollection(balancer balancer.Balancer, config config. // Redirect an incoming user to: CDN (only for /hls), closest node (geolocate) // or another service (like mist HLS) on the current host for playback. func (c *GeolocationHandlersCollection) RedirectHandler() httprouter.Handle { - return func(w http.ResponseWriter, r *http.Request, params httprouter.Params) { host := r.Host pathType, prefix, playbackID, pathTmpl := parsePlaybackID(r.URL.Path) @@ -193,13 +192,15 @@ func (c *GeolocationHandlersCollection) RedirectHandler() httprouter.Handle { } rPath := fmt.Sprintf(pathTmpl, fullPlaybackID) - rURL := fmt.Sprintf("%s://%s%s?%s", protocol(r), bestNode, rPath, r.URL.RawQuery) - rURL, err = c.resolveNodeURL(rURL) + rURL, err := c.resolveNodeURL(fmt.Sprintf("%s://%s%s?%s", protocol(r), bestNode, rPath, r.URL.RawQuery)) if err != nil { glog.Errorf("failed to resolve node URL playbackID=%s err=%s", playbackID, err) w.WriteHeader(http.StatusInternalServerError) return } + + c.alternativeNodeDomain(r, rURL) + var redirectType = "playback" if isStudioReq { redirectType = "ingest" @@ -214,37 +215,68 @@ func (c *GeolocationHandlersCollection) RedirectHandler() httprouter.Handle { "lon": lon, }) glog.Infof(string(jsonRedirectInfo)) - http.Redirect(w, r, rURL, http.StatusTemporaryRedirect) + http.Redirect(w, r, rURL.String(), http.StatusTemporaryRedirect) + } +} + +// alternativeNodeDomain switches the domain if certain conditions are matched +func (c *GeolocationHandlersCollection) alternativeNodeDomain(req *http.Request, redirectUrl *url.URL) { + if len(c.Config.LBReplaceDomains) < 1 || req == nil || req.URL == nil || redirectUrl == nil { + return + } + + switchDomain := false + for _, referer := range c.Config.LBReplaceDomainReferers { + switchDomain = strings.Contains(req.Header.Get("Referer"), referer) + if switchDomain { + break + } + } + + if !switchDomain { + for k, v := range c.Config.LBReplaceDomainQueryParams { + switchDomain = req.URL.Query().Get(k) == v + if switchDomain { + break + } + } + } + + if switchDomain { + for old, replace := range c.Config.LBReplaceDomains { + redirectUrl.Host = strings.Replace(redirectUrl.Host, old, replace, 1) + } } } // Given a dtsc:// or https:// url, resolve the proper address of the node via serf tags -func (c *GeolocationHandlersCollection) resolveNodeURL(streamURL string) (string, error) { +func (c *GeolocationHandlersCollection) resolveNodeURL(streamURL string) (*url.URL, error) { u, err := url.Parse(streamURL) if err != nil { - return "", err + return nil, err } nodeName := u.Host protocol := u.Scheme member, err := c.clusterMember(map[string]string{}, "alive", nodeName) if err != nil { - return "", err + return nil, err } addr, has := member.Tags[protocol] if !has { glog.V(7).Infof("no tag found, not tag resolving protocol=%s nodeName=%s", protocol, nodeName) - return streamURL, nil + return u, nil } u2, err := url.Parse(addr) if err != nil { err = fmt.Errorf("node has unparsable tag!! nodeName=%s protocol=%s tag=%s", nodeName, protocol, addr) glog.Error(err) - return "", err + return nil, err } u2.Path = filepath.Join(u2.Path, u.Path) u2.RawQuery = u.RawQuery - return u2.String(), nil + + return u2, nil } func (c *GeolocationHandlersCollection) clusterMember(filter map[string]string, status, name string) (cluster.Member, error) { @@ -341,7 +373,7 @@ func (c *GeolocationHandlersCollection) resolveReplicatedStream(dtscURL string, return "push://", nil } glog.V(7).Infof("replying to Mist STREAM_SOURCE request=%s response=%s", streamName, outURL) - return outURL, nil + return outURL.String(), nil } func (c *GeolocationHandlersCollection) getStreamPull(playbackID string, retryCount int) (string, error) { diff --git a/handlers/geolocation/geolocation_test.go b/handlers/geolocation/geolocation_test.go index ce674fac9..307a10924 100644 --- a/handlers/geolocation/geolocation_test.go +++ b/handlers/geolocation/geolocation_test.go @@ -614,3 +614,78 @@ func TestStreamPullRateLimit(t *testing.T) { time.Sleep(2 * time.Second) require.False(rateLimit.shouldLimit(playbackID1)) } + +func TestGeolocationHandlersCollection_alternativeNodeDomain(t *testing.T) { + conf := config.Cli{ + LBReplaceDomains: map[string]string{"domain1.net": "domain2.net"}, + LBReplaceDomainReferers: []string{"abc"}, + LBReplaceDomainQueryParams: map[string]string{"qpkey": "qpvalue"}, + } + type args struct { + req *http.Request + redirectUrl *url.URL + } + tests := []struct { + name string + args args + expectedHost string + }{ + { + name: "no replacement", + args: args{ + req: &http.Request{URL: &url.URL{}}, + redirectUrl: &url.URL{Host: "foo"}, + }, + expectedHost: "foo", + }, + { + name: "referer match", + args: args{ + req: &http.Request{URL: &url.URL{}, Header: map[string][]string{"Referer": {"abc"}}}, + redirectUrl: &url.URL{Host: "foo.domain1.net"}, + }, + expectedHost: "foo.domain2.net", + }, + { + name: "queryparam match", + args: args{ + req: &http.Request{URL: &url.URL{RawQuery: "qpkey=qpvalue"}}, + redirectUrl: &url.URL{Host: "foo.domain1.net"}, + }, + expectedHost: "foo.domain2.net", + }, + { + name: "no replacement - empty queryparam", + args: args{ + req: &http.Request{URL: &url.URL{RawQuery: "qpkey="}}, + redirectUrl: &url.URL{Host: "foo.domain1.net"}, + }, + expectedHost: "foo.domain1.net", + }, + { + name: "partial referer match", + args: args{ + req: &http.Request{URL: &url.URL{}, Header: map[string][]string{"Referer": {"abcd"}}}, + redirectUrl: &url.URL{Host: "foo.domain1.net"}, + }, + expectedHost: "foo.domain2.net", + }, + { + name: "partial domain match", + args: args{ + req: &http.Request{URL: &url.URL{}, Header: map[string][]string{"Referer": {"abc"}}}, + redirectUrl: &url.URL{Host: "foo.domain1.nets"}, + }, + expectedHost: "foo.domain2.nets", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &GeolocationHandlersCollection{ + Config: conf, + } + c.alternativeNodeDomain(tt.args.req, tt.args.redirectUrl) + require.Equal(t, tt.expectedHost, tt.args.redirectUrl.Host) + }) + } +} diff --git a/main.go b/main.go index 59714d6b3..12dfe64ab 100644 --- a/main.go +++ b/main.go @@ -141,6 +141,9 @@ func main() { fs.StringVar(&cli.LBReplaceHostMatch, "lb-replace-host-match", "", "What to match on the hostname for node replacement e.g. sto") config.CommaSliceFlag(fs, &cli.LBReplaceHostList, "lb-replace-host-list", []string{}, "List of hostnames to replace with for node replacement") fs.IntVar(&cli.LBReplaceHostPercent, "lb-replace-host-percent", 0, "Percentage of matching requests to replace host on") + config.CommaMapFlag(fs, &cli.LBReplaceDomains, "lb-replace-domains", map[string]string{}, "Comma-separated map of domains to replace in load balancing. e.g. domain1.com=domain2.com") + config.CommaSliceFlag(fs, &cli.LBReplaceDomainReferers, "lb-replace-domain-referers", []string{}, "List of referer headers to match for load balancing domain replacement") + config.CommaMapFlag(fs, &cli.LBReplaceDomainQueryParams, "lb-replace-domain-queryparams", map[string]string{}, "Comma-separated map of queryparams to match for load balancing domain replacement. e.g. queryparamName=queryparamValue") pprofPort := fs.Int("pprof-port", 6061, "Pprof listen port") fs.String("send-audio", "", "[DEPRECATED] ignored, will be removed")