Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ jobs:
make generate

- name: golangci-lint
uses: golangci/golangci-lint-action@v7
uses: golangci/golangci-lint-action@v4
with:
version: latest

Expand Down
2 changes: 1 addition & 1 deletion balancer/catabalancer/catalyst_balancer.go
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ func (c *CataBalancer) refreshNodes(ctx context.Context) (stats, error) {
if err != nil {
return s, fmt.Errorf("failed to query node stats: %w", err)
}
defer rows.Close()
defer rows.Close() // nolint:errcheck

// Process the result set
for rows.Next() {
Expand Down
6 changes: 3 additions & 3 deletions balancer/mist/mist_balancer.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ func (b *MistBalancer) changeLoadBalancerServers(ctx context.Context, server, ac
glog.Errorf("Error making request: %v", err)
return nil, err
}
defer resp.Body.Close()
defer resp.Body.Close() // nolint:errcheck

bytes, err := io.ReadAll(resp.Body)

Expand Down Expand Up @@ -212,7 +212,7 @@ func (b *MistBalancer) getMistLoadBalancerServers(ctx context.Context) (map[stri
glog.Errorf("Error making request: %v", err)
return nil, err
}
defer resp.Body.Close()
defer resp.Body.Close() // nolint:errcheck

if resp.StatusCode != http.StatusOK {
b, _ := io.ReadAll(resp.Body)
Expand Down Expand Up @@ -473,7 +473,7 @@ func (b *MistBalancer) mistUtilLoadRequest(ctx context.Context, route, stream, l
if err != nil {
return "", err
}
defer resp.Body.Close()
defer resp.Body.Close() // nolint:errcheck

if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("GET request '%s' failed with http status code %d", murl, resp.StatusCode)
Expand Down
4 changes: 2 additions & 2 deletions c2pa/c2pa_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func TestSign(t *testing.T) {

outFile := "test/tiny_output_signed.mp4"
c := NewC2PA("es256", "test/es256_private.key", "test/es256_certs.pem")
defer os.Remove(outFile)
defer os.Remove(outFile) // nolint:errcheck

err = c.SignFile("test/tiny.mp4", outFile, "Tiny", "")

Expand All @@ -38,7 +38,7 @@ func TestSignWithParent(t *testing.T) {

outFile := "test/tiny_cut_output_signed.mp4"
c := NewC2PA("es256", "test/es256_private.key", "test/es256_certs.pem")
defer os.Remove(outFile)
defer os.Remove(outFile) // nolint:errcheck

err = c.SignFile("test/tiny_cut.mp4", outFile, "Tiny", "test/tiny_signed.mp4")

Expand Down
6 changes: 3 additions & 3 deletions clients/broadcaster_remote.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func findBroadcaster(c Credentials) (BroadcasterList, error) {
if err != nil {
return BroadcasterList{}, fmt.Errorf("http do(%s): %v", requestURL, err)
}
defer res.Body.Close()
defer res.Body.Close() // nolint:errcheck

if !httpOk(res.StatusCode) {
return BroadcasterList{}, fmt.Errorf("http GET(%s) returned %d %s", requestURL, res.StatusCode, res.Status)
Expand Down Expand Up @@ -130,7 +130,7 @@ func CreateStream(c Credentials, streamName string, profiles []video.EncodedProf
if err != nil {
return "", fmt.Errorf("http do(%s): %v", requestURL, err)
}
defer res.Body.Close()
defer res.Body.Close() // nolint:errcheck

if !httpOk(res.StatusCode) {
return "", fmt.Errorf("http POST(%s) returned %d %s", requestURL, res.StatusCode, res.Status)
Expand Down Expand Up @@ -164,7 +164,7 @@ func ReleaseManifestID(c Credentials, manifestId string) error {
if err != nil {
return fmt.Errorf("Releasing Manifest ID failed. URL: %s, manifestID: %s, err: %s", requestURL, manifestId, err)
}
defer res.Body.Close()
defer res.Body.Close() // nolint:errcheck

if !httpOk(res.StatusCode) {
return fmt.Errorf("Releasing Manifest ID failed. URL: %s, manifestID: %s, HTTP Code: %s", requestURL, manifestId, res.Status)
Expand Down
10 changes: 5 additions & 5 deletions clients/manifest.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func RecordingBackupCheck(requestID string, primaryManifestURL, osTransferURL *u
var rc io.ReadCloser
rc, actualSegURL, err = GetFileWithBackup(context.Background(), requestID, segURL.String(), dStorage)
if rc != nil {
rc.Close()
rc.Close() // nolint:errcheck
}
return err
}, DownloadRetryBackoff())
Expand Down Expand Up @@ -189,7 +189,7 @@ func downloadManifest(requestID, sourceManifestOSURL string) (playlist m3u8.Play
}
return err
}
defer rc.Close()
defer rc.Close() // nolint:errcheck

data := new(bytes.Buffer)
_, err = data.ReadFrom(rc)
Expand Down Expand Up @@ -390,7 +390,7 @@ func ClipInputManifest(requestID, sourceURL, clipTargetUrl string, startTimeUnix
if err != nil {
return nil, fmt.Errorf("error clipping: failed to create temp clipping storage dir: %w", err)
}
defer os.RemoveAll(clipStorageDir)
defer os.RemoveAll(clipStorageDir) // nolint:errcheck

// Download start/end segments and clip
for i, v := range segsToClip {
Expand Down Expand Up @@ -451,7 +451,7 @@ func ClipInputManifest(requestID, sourceURL, clipTargetUrl string, startTimeUnix
if err != nil {
return nil, fmt.Errorf("error clipping: failed to open clipped segment %d: %w", v.SeqId, err)
}
defer clippedSegmentFile.Close()
defer clippedSegmentFile.Close() // nolint:errcheck

clippedSegmentOSFilename := "clip_" + strconv.FormatUint(v.SeqId, 10) + ".ts"
err = UploadToOSURL(clipTargetUrl, clippedSegmentOSFilename, clippedSegmentFile, MaxCopyFileDuration)
Expand Down Expand Up @@ -553,7 +553,7 @@ func GetFirstRenditionURL(requestID string, masterManifestURL *url.URL) (*url.UR
if err != nil {
return fmt.Errorf("error downloading manifest %s: %w", masterManifestURL.Redacted(), err)
}
defer rc.Close()
defer rc.Close() // nolint:errcheck

playlist, playlistType, err = m3u8.DecodeFrom(rc, true)
if err != nil {
Expand Down
9 changes: 7 additions & 2 deletions clients/mist_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ func TestItCanParseAMistStreamStatus(t *testing.T) {

mc := &MistClient{
HttpReqUrl: svr.URL,
httpClient: newRetryableClient(&http.Client{Timeout: MistClientTimeout}),
}

msi, err := mc.GetStreamInfo("some-stream-name")
Expand All @@ -258,6 +259,7 @@ func TestItCanParseAMistStreamErrorStatus(t *testing.T) {

mc := &MistClient{
HttpReqUrl: svr.URL,
httpClient: newRetryableClient(&http.Client{Timeout: MistClientTimeout}),
}

_, err := mc.GetStreamInfo("some-stream-name")
Expand Down Expand Up @@ -286,6 +288,7 @@ func TestItRetriesFailingRequests(t *testing.T) {

mc := &MistClient{
HttpReqUrl: svr.URL,
httpClient: newRetryableClient(&http.Client{Timeout: MistClientTimeout}),
}

_, err := mc.GetStreamInfo("some-stream-name")
Expand All @@ -307,6 +310,7 @@ func TestItFailsWhenMaxRetriesReached(t *testing.T) {

mc := &MistClient{
HttpReqUrl: svr.URL,
httpClient: newRetryableClient(&http.Client{Timeout: MistClientTimeout}),
}

_, err := mc.GetStreamInfo("some-stream-name")
Expand Down Expand Up @@ -385,8 +389,9 @@ func TestItCanGetStreamStats(t *testing.T) {
defer svr.Close()

mc := &MistClient{
ApiUrl: svr.URL,
cache: cache.New(200*time.Millisecond, time.Minute),
ApiUrl: svr.URL,
cache: cache.New(200*time.Millisecond, time.Minute),
httpClient: newRetryableClient(&http.Client{Timeout: MistClientTimeout}),
}

status, err := mc.GetState()
Expand Down
9 changes: 6 additions & 3 deletions config/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,12 @@ type Cli struct {
SerfEventBuffer int
SerfMaxQueueDepth int

LBReplaceHostMatch string
LBReplaceHostPercent int
LBReplaceHostList []string
LBReplaceHostMatch string
LBReplaceHostPercent int
LBReplaceHostList []string
LBReplaceDomains map[string]string
LBReplaceDomainReferers []string
LBReplaceDomainQueryParams map[string]string
}

// Return our own URL for callback trigger purposes
Expand Down
54 changes: 43 additions & 11 deletions handlers/geolocation/geolocation.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ func NewGeolocationHandlersCollection(balancer balancer.Balancer, config config.
// Redirect an incoming user to: CDN (only for /hls), closest node (geolocate)
// or another service (like mist HLS) on the current host for playback.
func (c *GeolocationHandlersCollection) RedirectHandler() httprouter.Handle {

return func(w http.ResponseWriter, r *http.Request, params httprouter.Params) {
host := r.Host
pathType, prefix, playbackID, pathTmpl := parsePlaybackID(r.URL.Path)
Expand Down Expand Up @@ -193,13 +192,15 @@ func (c *GeolocationHandlersCollection) RedirectHandler() httprouter.Handle {
}

rPath := fmt.Sprintf(pathTmpl, fullPlaybackID)
rURL := fmt.Sprintf("%s://%s%s?%s", protocol(r), bestNode, rPath, r.URL.RawQuery)
rURL, err = c.resolveNodeURL(rURL)
rURL, err := c.resolveNodeURL(fmt.Sprintf("%s://%s%s?%s", protocol(r), bestNode, rPath, r.URL.RawQuery))
if err != nil {
glog.Errorf("failed to resolve node URL playbackID=%s err=%s", playbackID, err)
w.WriteHeader(http.StatusInternalServerError)
return
}

c.alternativeNodeDomain(r, rURL)

var redirectType = "playback"
if isStudioReq {
redirectType = "ingest"
Expand All @@ -214,37 +215,68 @@ func (c *GeolocationHandlersCollection) RedirectHandler() httprouter.Handle {
"lon": lon,
})
glog.Infof(string(jsonRedirectInfo))
http.Redirect(w, r, rURL, http.StatusTemporaryRedirect)
http.Redirect(w, r, rURL.String(), http.StatusTemporaryRedirect)
}
}

// alternativeNodeDomain switches the domain if certain conditions are matched
func (c *GeolocationHandlersCollection) alternativeNodeDomain(req *http.Request, redirectUrl *url.URL) {
if len(c.Config.LBReplaceDomains) < 1 || req == nil || req.URL == nil || redirectUrl == nil {
return
}

switchDomain := false
for _, referer := range c.Config.LBReplaceDomainReferers {
switchDomain = strings.Contains(req.Header.Get("Referer"), referer)
if switchDomain {
break
}
}

if !switchDomain {
for k, v := range c.Config.LBReplaceDomainQueryParams {
switchDomain = req.URL.Query().Get(k) == v
if switchDomain {
break
}
}
}

if switchDomain {
for old, replace := range c.Config.LBReplaceDomains {
redirectUrl.Host = strings.Replace(redirectUrl.Host, old, replace, 1)
}
}
}

// Given a dtsc:// or https:// url, resolve the proper address of the node via serf tags
func (c *GeolocationHandlersCollection) resolveNodeURL(streamURL string) (string, error) {
func (c *GeolocationHandlersCollection) resolveNodeURL(streamURL string) (*url.URL, error) {
u, err := url.Parse(streamURL)
if err != nil {
return "", err
return nil, err
}
nodeName := u.Host
protocol := u.Scheme

member, err := c.clusterMember(map[string]string{}, "alive", nodeName)
if err != nil {
return "", err
return nil, err
}
addr, has := member.Tags[protocol]
if !has {
glog.V(7).Infof("no tag found, not tag resolving protocol=%s nodeName=%s", protocol, nodeName)
return streamURL, nil
return u, nil
}
u2, err := url.Parse(addr)
if err != nil {
err = fmt.Errorf("node has unparsable tag!! nodeName=%s protocol=%s tag=%s", nodeName, protocol, addr)
glog.Error(err)
return "", err
return nil, err
}
u2.Path = filepath.Join(u2.Path, u.Path)
u2.RawQuery = u.RawQuery
return u2.String(), nil

return u2, nil
}

func (c *GeolocationHandlersCollection) clusterMember(filter map[string]string, status, name string) (cluster.Member, error) {
Expand Down Expand Up @@ -341,7 +373,7 @@ func (c *GeolocationHandlersCollection) resolveReplicatedStream(dtscURL string,
return "push://", nil
}
glog.V(7).Infof("replying to Mist STREAM_SOURCE request=%s response=%s", streamName, outURL)
return outURL, nil
return outURL.String(), nil
}

func (c *GeolocationHandlersCollection) getStreamPull(playbackID string, retryCount int) (string, error) {
Expand Down
75 changes: 75 additions & 0 deletions handlers/geolocation/geolocation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -614,3 +614,78 @@ func TestStreamPullRateLimit(t *testing.T) {
time.Sleep(2 * time.Second)
require.False(rateLimit.shouldLimit(playbackID1))
}

func TestGeolocationHandlersCollection_alternativeNodeDomain(t *testing.T) {
conf := config.Cli{
LBReplaceDomains: map[string]string{"domain1.net": "domain2.net"},
LBReplaceDomainReferers: []string{"abc"},
LBReplaceDomainQueryParams: map[string]string{"qpkey": "qpvalue"},
}
type args struct {
req *http.Request
redirectUrl *url.URL
}
tests := []struct {
name string
args args
expectedHost string
}{
{
name: "no replacement",
args: args{
req: &http.Request{URL: &url.URL{}},
redirectUrl: &url.URL{Host: "foo"},
},
expectedHost: "foo",
},
{
name: "referer match",
args: args{
req: &http.Request{URL: &url.URL{}, Header: map[string][]string{"Referer": {"abc"}}},
redirectUrl: &url.URL{Host: "foo.domain1.net"},
},
expectedHost: "foo.domain2.net",
},
{
name: "queryparam match",
args: args{
req: &http.Request{URL: &url.URL{RawQuery: "qpkey=qpvalue"}},
redirectUrl: &url.URL{Host: "foo.domain1.net"},
},
expectedHost: "foo.domain2.net",
},
{
name: "no replacement - empty queryparam",
args: args{
req: &http.Request{URL: &url.URL{RawQuery: "qpkey="}},
redirectUrl: &url.URL{Host: "foo.domain1.net"},
},
expectedHost: "foo.domain1.net",
},
{
name: "partial referer match",
args: args{
req: &http.Request{URL: &url.URL{}, Header: map[string][]string{"Referer": {"abcd"}}},
redirectUrl: &url.URL{Host: "foo.domain1.net"},
},
expectedHost: "foo.domain2.net",
},
{
name: "partial domain match",
args: args{
req: &http.Request{URL: &url.URL{}, Header: map[string][]string{"Referer": {"abc"}}},
redirectUrl: &url.URL{Host: "foo.domain1.nets"},
},
expectedHost: "foo.domain2.nets",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &GeolocationHandlersCollection{
Config: conf,
}
c.alternativeNodeDomain(tt.args.req, tt.args.redirectUrl)
require.Equal(t, tt.expectedHost, tt.args.redirectUrl.Host)
})
}
}
3 changes: 3 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,9 @@ func main() {
fs.StringVar(&cli.LBReplaceHostMatch, "lb-replace-host-match", "", "What to match on the hostname for node replacement e.g. sto")
config.CommaSliceFlag(fs, &cli.LBReplaceHostList, "lb-replace-host-list", []string{}, "List of hostnames to replace with for node replacement")
fs.IntVar(&cli.LBReplaceHostPercent, "lb-replace-host-percent", 0, "Percentage of matching requests to replace host on")
config.CommaMapFlag(fs, &cli.LBReplaceDomains, "lb-replace-domains", map[string]string{}, "Comma-separated map of domains to replace in load balancing. e.g. domain1.com=domain2.com")
config.CommaSliceFlag(fs, &cli.LBReplaceDomainReferers, "lb-replace-domain-referers", []string{}, "List of referer headers to match for load balancing domain replacement")
config.CommaMapFlag(fs, &cli.LBReplaceDomainQueryParams, "lb-replace-domain-queryparams", map[string]string{}, "Comma-separated map of queryparams to match for load balancing domain replacement. e.g. queryparamName=queryparamValue")
pprofPort := fs.Int("pprof-port", 6061, "Pprof listen port")

fs.String("send-audio", "", "[DEPRECATED] ignored, will be removed")
Expand Down
Loading