Skip to content

Commit

Permalink
Merge pull request #6 from libdns/bug-fix-record-creation
Browse files Browse the repository at this point in the history
Add tests and fix several bugs (closes #1)
  • Loading branch information
omarjatoi authored May 18, 2024
2 parents 72e1ced + be9ffd7 commit dfa9220
Show file tree
Hide file tree
Showing 3 changed files with 339 additions and 32 deletions.
18 changes: 18 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,24 @@ This provider expects the following configuration:
- `ACCOUNT_ID`: identifier for the account (only needed if using a user access token), see [accounts documentation](https://developer.dnsimple.com/v2/accounts/)
- `API_URL`: hostname for the API to use (defaults to `api.dnsimple.com`), only useful for testing purposes, see [sandox documentation](https://developer.dnsimple.com/sandbox/)

## Testing

In order to run the tests, you need to create an account on the [DNSimple sandbox environment](https://developer.dnsimple.com/sandbox/). After setup, create a new DNS zone, and create an `API_ACCESS_TOKEN` and take note of both. You will need both these values to run tests.

```
$ TEST_ZONE=example.com TEST_API_ACCESS_TOKEN=you_api_access_token go test -v
=== RUN Test_AppendRecords
--- PASS: Test_AppendRecords (1.23s)
=== RUN Test_DeleteRecords
--- PASS: Test_DeleteRecords (0.59s)
=== RUN Test_GetRecords
--- PASS: Test_GetRecords (0.58s)
=== RUN Test_SetRecords
--- PASS: Test_SetRecords (1.14s)
PASS
ok github.com/libdns/dnsimple 3.666s
```

## License

Licensed under the Apache License, Version 2.0 (the "License");
Expand Down
91 changes: 60 additions & 31 deletions provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func (p *Provider) initClient(ctx context.Context) {
p.once.Do(func() {
// Create new DNSimple client using the provided access token.
tc := dnsimple.StaticTokenHTTPClient(ctx, p.APIAccessToken)
c := *dnsimple.NewClient(tc)
c := dnsimple.NewClient(tc)
// Set the API URL if using a non-default API hostname (e.g. sandbox).
if p.APIURL != "" {
c.BaseURL = p.APIURL
Expand All @@ -43,16 +43,13 @@ func (p *Provider) initClient(ctx context.Context) {
p.AccountID = accountID
}

p.client = c
p.client = *c
})
}

// GetRecords lists all the records in the zone.
func (p *Provider) GetRecords(ctx context.Context, zone string) ([]libdns.Record, error) {
p.mutex.Lock()
defer p.mutex.Unlock()
p.initClient(ctx)

// Internal helper function to fetch records from the provider, note that this function assumes
// the called is holding a lock on the mutex and has already initialized the client.
func (p *Provider) getRecordsFromProvider(ctx context.Context, zone string) ([]libdns.Record, error) {
var records []libdns.Record

resp, err := p.client.Zones.ListRecords(ctx, p.AccountID, unFQDN(zone), &dnsimple.ZoneRecordListOptions{})
Expand All @@ -65,7 +62,7 @@ func (p *Provider) GetRecords(ctx context.Context, zone string) ([]libdns.Record
Type: r.Type,
Name: r.Name,
Value: r.Content,
TTL: time.Duration(r.TTL),
TTL: time.Duration(r.TTL * int(time.Second)),
Priority: uint(r.Priority),
}
records = append(records, record)
Expand All @@ -74,13 +71,19 @@ func (p *Provider) GetRecords(ctx context.Context, zone string) ([]libdns.Record
return records, nil
}

// AppendRecords adds records to the zone. It returns the records that were added.
func (p *Provider) AppendRecords(ctx context.Context, zone string, records []libdns.Record) ([]libdns.Record, error) {
// GetRecords lists all the records in the zone.
func (p *Provider) GetRecords(ctx context.Context, zone string) ([]libdns.Record, error) {
p.mutex.Lock()
defer p.mutex.Unlock()
p.initClient(ctx)

var appendedRecords []libdns.Record
return p.getRecordsFromProvider(ctx, zone)
}

// Internal helper function that actually creates the records, does not hold a lock since the called is
// assumed to be holding a lock on the mutex and is in charge of making sure the client is initialized.
func (p *Provider) createRecords(ctx context.Context, zone string, records []libdns.Record) ([]libdns.Record, error) {
var createdRecords []libdns.Record

// Get the Zone ID from zone name
resp, err := p.client.Zones.GetZone(ctx, p.AccountID, unFQDN(zone))
Expand All @@ -95,7 +98,7 @@ func (p *Provider) AppendRecords(ctx context.Context, zone string, records []lib
Type: r.Type,
Name: &r.Name,
Content: r.Value,
TTL: int(r.TTL),
TTL: int(r.TTL.Seconds()),
Priority: int(r.Priority),
}
resp, err := p.client.Zones.CreateRecord(ctx, p.AccountID, unFQDN(zone), attrs)
Expand All @@ -105,31 +108,37 @@ func (p *Provider) AppendRecords(ctx context.Context, zone string, records []lib
// See https://developer.dnsimple.com/v2/zones/records/#createZoneRecord
if resp.HTTPResponse.StatusCode == http.StatusCreated {
r.ID = strconv.FormatInt(resp.Data.ID, 10)
appendedRecords = append(appendedRecords, r)
createdRecords = append(createdRecords, r)
} else {
return nil, fmt.Errorf("error creating record: %s, error: %s", r.Name, resp.HTTPResponse.Status)
}
}
return appendedRecords, nil
return createdRecords, nil
}

// SetRecords sets the records in the zone, either by updating existing records or creating new ones.
// It returns the updated records.
func (p *Provider) SetRecords(ctx context.Context, zone string, records []libdns.Record) ([]libdns.Record, error) {
// AppendRecords adds records to the zone. It returns the records that were added.
func (p *Provider) AppendRecords(ctx context.Context, zone string, records []libdns.Record) ([]libdns.Record, error) {
p.mutex.Lock()
defer p.mutex.Unlock()
p.initClient(ctx)

var setRecords []libdns.Record
return p.createRecords(ctx, zone, records)
}

existingRecords, err := p.GetRecords(ctx, unFQDN(zone))
// Internal helper function to get the lists of records to create and update respectively
func (p *Provider) getRecordsToCreateAndUpdate(ctx context.Context, zone string, records []libdns.Record) ([]libdns.Record, []libdns.Record, error) {
existingRecords, err := p.getRecordsFromProvider(ctx, zone)
if err != nil {
return nil, err
return nil, nil, err
}
var recordsToUpdate []libdns.Record

updateMap := make(map[libdns.Record]bool)
var recordsToCreate []libdns.Record

// Figure out which records exist and need to be updated
for i, r := range records {
for _, r := range records {
updateMap[r] = true
for _, er := range existingRecords {
if r.Name != er.Name {
continue
Expand All @@ -138,15 +147,35 @@ func (p *Provider) SetRecords(ctx context.Context, zone string, records []libdns
r.ID = er.ID
}
recordsToUpdate = append(recordsToUpdate, r)
// If this is a record that exists and will be updated, remove it from
// the records slice, so everything left will be a record that does not
// exist and needs to be created.
records = append(records[:i], records[i+1:]...)
updateMap[r] = false
}
}
// If the record is not updating an existing record, we want to create it
for r, updating := range updateMap {
if updating {
recordsToCreate = append(recordsToCreate, r)
}
}

return recordsToCreate, recordsToUpdate, nil
}

// SetRecords sets the records in the zone, either by updating existing records or creating new ones.
// It returns the updated records.
func (p *Provider) SetRecords(ctx context.Context, zone string, records []libdns.Record) ([]libdns.Record, error) {
p.mutex.Lock()
defer p.mutex.Unlock()
p.initClient(ctx)

var setRecords []libdns.Record

recordsToCreate, recordsToUpdate, err := p.getRecordsToCreateAndUpdate(ctx, zone, records)
if err != nil {
return nil, err
}

// Create new records and append them to 'setRecords'
createdRecords, err := p.AppendRecords(ctx, unFQDN(zone), records)
createdRecords, err := p.createRecords(ctx, zone, recordsToCreate)
if err != nil {
return nil, err
}
Expand All @@ -168,7 +197,7 @@ func (p *Provider) SetRecords(ctx context.Context, zone string, records []libdns
Type: r.Type,
Name: &r.Name,
Content: r.Value,
TTL: int(r.TTL),
TTL: int(r.TTL.Seconds()),
Priority: int(r.Priority),
}
id, err := strconv.ParseInt(r.ID, 10, 64)
Expand Down Expand Up @@ -232,12 +261,12 @@ func (p *Provider) DeleteRecords(ctx context.Context, zone string, records []lib
// If we received records without an ID earlier, we're going to try and figure out the ID by calling
// GetRecords and comparing the record name. If we're able to find it, we'll delete it, otherwise
// we'll append it to our list of failed to delete records.
fetchedRecords, err := p.GetRecords(ctx, unFQDN(zone))
existingRecords, err := p.getRecordsFromProvider(ctx, zone)
if err != nil {
return nil, fmt.Errorf("failed to fetch records: %s", err.Error())
return nil, fmt.Errorf("failed to get existing records: %s", err.Error())
}
for _, r := range noID {
for _, fr := range fetchedRecords {
for _, fr := range existingRecords {
if r.Name != fr.Name {
continue
}
Expand Down
Loading

0 comments on commit dfa9220

Please sign in to comment.