diff --git a/containerfs/etc/passwd b/containerfs/etc/passwd new file mode 100644 index 0000000..c075617 --- /dev/null +++ b/containerfs/etc/passwd @@ -0,0 +1 @@ +appuser:x:1001:1001:App User:/:/sbin/nologin \ No newline at end of file diff --git a/cycle_test.go b/cycle_test.go new file mode 100644 index 0000000..e2417f0 --- /dev/null +++ b/cycle_test.go @@ -0,0 +1,271 @@ +package main + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/joeig/go-powerdns/v3" +) + +func TestValidateWebhookURL_Empty(t *testing.T) { + if err := validateWebhookURL(""); err != nil { + t.Errorf("expected nil for empty URL, got %v", err) + } +} + +func TestValidateWebhookURL_HTTP(t *testing.T) { + if err := validateWebhookURL("http://example.com/webhook"); err != nil { + t.Errorf("expected nil for http URL, got %v", err) + } +} + +func TestValidateWebhookURL_HTTPS(t *testing.T) { + if err := validateWebhookURL("https://example.com/webhook"); err != nil { + t.Errorf("expected nil for https URL, got %v", err) + } +} + +func TestValidateWebhookURL_BadScheme(t *testing.T) { + err := validateWebhookURL("ftp://example.com/webhook") + if err == nil { + t.Fatal("expected error for ftp scheme") + } + if !strings.Contains(err.Error(), "http or https") { + t.Errorf("expected scheme error, got %v", err) + } +} + +func TestValidateWebhookURL_NoHost(t *testing.T) { + err := validateWebhookURL("https://") + if err == nil { + t.Fatal("expected error for missing host") + } + if !strings.Contains(err.Error(), "host") { + t.Errorf("expected host error, got %v", err) + } +} + +func TestValidateWebhookURL_InvalidURL(t *testing.T) { + err := validateWebhookURL("://bad") + if err == nil { + t.Fatal("expected error for invalid URL") + } +} + +func TestRunCycle_IPChanged_SendsNotification(t *testing.T) { + var ntfyBody string + ntfyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body := make([]byte, r.ContentLength) + r.Body.Read(body) + ntfyBody = string(body) + w.WriteHeader(http.StatusOK) + })) + defer ntfyServer.Close() + + mock := &MockRecordsClient{ + GetResult: []powerdns.RRset{ + {Records: []powerdns.Record{{Content: strPtr("1.2.3.4")}}}, + }, + } + + updated, err := runCycle( + context.Background(), + "5.6.7.8", + rrsetSlice{{Name: "myhost.", Zone: "example.com."}}, + mock, + "", + ntfyServer.URL+"/mytopic", + ntfyServer.Client(), + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !updated { + t.Fatal("expected updated to be true") + } + if !mock.ChangeCalled { + t.Fatal("expected Change to be called since IP differs") + } + if !strings.Contains(ntfyBody, "5.6.7.8") { + t.Errorf("expected ntfy notification to contain IP, got %q", ntfyBody) + } +} + +func TestRunCycle_IPUnchanged_NoNotification(t *testing.T) { + notificationSent := false + discordServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + notificationSent = true + w.WriteHeader(http.StatusNoContent) + })) + defer discordServer.Close() + + mock := &MockRecordsClient{ + GetResult: []powerdns.RRset{ + {Records: []powerdns.Record{{Content: strPtr("1.2.3.4")}}}, + }, + } + + updated, err := runCycle( + context.Background(), + "1.2.3.4", + rrsetSlice{{Name: "myhost.", Zone: "example.com."}}, + mock, + discordServer.URL+"/webhook", + "", + discordServer.Client(), + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if updated { + t.Fatal("expected updated to be false since IP unchanged") + } + if notificationSent { + t.Fatal("expected no notification when IP unchanged") + } +} + +func TestRunCycle_UpdateError_SendsCriticalNotification(t *testing.T) { + var ntfyBody string + ntfyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body := make([]byte, r.ContentLength) + r.Body.Read(body) + ntfyBody = string(body) + w.WriteHeader(http.StatusOK) + })) + defer ntfyServer.Close() + + mock := &MockRecordsClient{ + GetError: fmt.Errorf("API error"), + } + + _, err := runCycle( + context.Background(), + "5.6.7.8", + rrsetSlice{{Name: "myhost.", Zone: "example.com."}}, + mock, + "", + ntfyServer.URL+"/mytopic", + ntfyServer.Client(), + ) + if err == nil { + t.Fatal("expected error when update fails") + } + if !strings.Contains(ntfyBody, "CRITICAL") { + t.Errorf("expected CRITICAL notification, got %q", ntfyBody) + } +} + +func TestRunCycle_UpdateError_NoWebhooks(t *testing.T) { + mock := &MockRecordsClient{ + GetError: fmt.Errorf("API error"), + } + + _, err := runCycle( + context.Background(), + "5.6.7.8", + rrsetSlice{{Name: "myhost.", Zone: "example.com."}}, + mock, + "", + "", + http.DefaultClient, + ) + if err == nil { + t.Fatal("expected error when update fails") + } +} + +func TestRunCycle_UpdateError_NotificationAlsoFails(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer server.Close() + + mock := &MockRecordsClient{ + GetError: fmt.Errorf("API error"), + } + + _, err := runCycle( + context.Background(), + "5.6.7.8", + rrsetSlice{{Name: "myhost.", Zone: "example.com."}}, + mock, + "", + server.URL+"/mytopic", + server.Client(), + ) + if err == nil { + t.Fatal("expected error when update fails") + } +} + +func TestRunCycle_NoUpdatesNoErrors_NoNotification(t *testing.T) { + notificationSent := false + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + notificationSent = true + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + mock := &MockRecordsClient{ + GetResult: []powerdns.RRset{ + {Records: []powerdns.Record{{Content: strPtr("1.2.3.4")}}}, + }, + } + + updated, err := runCycle( + context.Background(), + "1.2.3.4", + rrsetSlice{{Name: "myhost.", Zone: "example.com."}}, + mock, + server.URL+"/discord", + server.URL+"/ntfy", + server.Client(), + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if updated { + t.Fatal("expected updated to be false since IP unchanged") + } + if notificationSent { + t.Fatal("expected no notification when IP unchanged and no errors") + } +} + +func TestRunCycle_NotificationError_DoesNotBlock(t *testing.T) { + badServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer badServer.Close() + + badClient := badServer.Client() + closeServer := func() { badServer.Close() } + closeServer() + + mock := &MockRecordsClient{ + GetResult: []powerdns.RRset{ + {Records: []powerdns.Record{{Content: strPtr("1.2.3.4")}}}, + }, + } + + updated, err := runCycle( + context.Background(), + "5.6.7.8", + rrsetSlice{{Name: "myhost.", Zone: "example.com."}}, + mock, + "http://127.0.0.1:1/webhook", + "", + badClient, + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !updated { + t.Fatal("expected updated to be true even when notification fails") + } +} diff --git a/docker-compose.example.yml b/docker-compose.example.yml new file mode 100644 index 0000000..0f00062 --- /dev/null +++ b/docker-compose.example.yml @@ -0,0 +1,10 @@ +services: + tracker: + image: "kr0nus/iptracker:latest" + command: + - "-b" + - "-i30s" + - '--pdns_url=http://192.168.0.10:8081' + - '--pdns_apikey=super_secret_p@ssword!' + - '-r test.kronus.dev,kronus.dev' + - '-r test2.kronus.dev,kronus.dev' diff --git a/docker-compose.integration.yml b/docker-compose.integration.yml index 3bb7490..5d9838a 100644 --- a/docker-compose.integration.yml +++ b/docker-compose.integration.yml @@ -7,12 +7,6 @@ services: - "8081:8081" - "1053:53" - "1053:53/udp" - healthcheck: - test: ["CMD-SHELL", "curl -sf -H 'X-API-Key: testapikey' http://localhost:8081/api/v1/servers/localhost || exit 1"] - interval: 5s - timeout: 3s - retries: 30 - start_period: 10s ntfy: image: binwiederhier/ntfy:latest diff --git a/dockerfile b/dockerfile index 509b7b0..b9f914d 100644 --- a/dockerfile +++ b/dockerfile @@ -1,9 +1,14 @@ FROM golang:1.26 AS build WORKDIR /app COPY . . -RUN go mod download && go build -o iptracker -v ./... && chmod +x iptracker +RUN go mod download && CGO_ENABLED=0 go build -a -ldflags "-extldflags '-static'" -o iptracker -v ./... + +FROM debian:trixie AS src_os +RUN apt update && apt install -y ca-certificates FROM scratch -USER bot -COPY --from=build /app/iptracker . -CMD ["iptracker"] \ No newline at end of file +COPY containerfs/ . +COPY --from=build --chmod=555 /app/iptracker . +COPY --from=src_os /etc/ssl /etc/ssl +USER appuser +ENTRYPOINT [ "./iptracker" ] \ No newline at end of file diff --git a/errors_test.go b/errors_test.go new file mode 100644 index 0000000..c03c95c --- /dev/null +++ b/errors_test.go @@ -0,0 +1,58 @@ +package main + +import ( + "errors" + "fmt" + "testing" +) + +func TestIPCheckError_Error(t *testing.T) { + err := &IPCheckError{Err: fmt.Errorf("some failure")} + msg := err.Error() + expected := "ip check: some failure" + if msg != expected { + t.Errorf("expected %q, got %q", expected, msg) + } +} + +func TestIPCheckError_Unwrap(t *testing.T) { + inner := fmt.Errorf("inner error") + err := &IPCheckError{Err: inner} + if !errors.Is(err, inner) { + t.Fatal("expected errors.Is to match inner error") + } +} + +func TestRecordUpdateError_Error(t *testing.T) { + err := &RecordUpdateError{Record: "myhost", Zone: "example.com", Err: fmt.Errorf("fail")} + msg := err.Error() + expected := `update record "myhost" in zone "example.com": fail` + if msg != expected { + t.Errorf("expected %q, got %q", expected, msg) + } +} + +func TestRecordUpdateError_Unwrap(t *testing.T) { + inner := fmt.Errorf("inner error") + err := &RecordUpdateError{Record: "myhost", Zone: "example.com", Err: inner} + if !errors.Is(err, inner) { + t.Fatal("expected errors.Is to match inner error") + } +} + +func TestNotificationError_Error(t *testing.T) { + err := &NotificationError{Service: "discord", Err: fmt.Errorf("timeout")} + msg := err.Error() + expected := "discord notification: timeout" + if msg != expected { + t.Errorf("expected %q, got %q", expected, msg) + } +} + +func TestNotificationError_Unwrap(t *testing.T) { + inner := fmt.Errorf("inner error") + err := &NotificationError{Service: "ntfy", Err: inner} + if !errors.Is(err, inner) { + t.Fatal("expected errors.Is to match inner error") + } +} diff --git a/integration_test.go b/integration_test.go index 9af88e4..26f7f31 100644 --- a/integration_test.go +++ b/integration_test.go @@ -10,6 +10,8 @@ import ( "io" "net/http" "os" + "os/exec" + "path/filepath" "strings" "testing" "time" @@ -390,3 +392,179 @@ func TestIntegration_MultipleRRSets(t *testing.T) { waitForNtfyMessage(t, ntfyURL, ntfyTopic, "IP address updated: "+testNewIP) } + +func buildBinary(t *testing.T) string { + t.Helper() + tmpDir := t.TempDir() + binary := filepath.Join(tmpDir, "iptracker") + output, err := exec.Command("go", "build", "-o", binary, ".").CombinedOutput() + if err != nil { + t.Fatalf("failed to build binary: %v\n%s", err, output) + } + return binary +} + +func runBinary(t *testing.T, binary string, args ...string) (string, int) { + t.Helper() + cmd := exec.Command(binary, args...) + var stderr bytes.Buffer + cmd.Stderr = &stderr + cmd.Run() + exitCode := 0 + if cmd.ProcessState != nil { + exitCode = cmd.ProcessState.ExitCode() + } + return stderr.String(), exitCode +} + +func TestIntegration_Main_MissingAPIKey(t *testing.T) { + binary := buildBinary(t) + stderr, exitCode := runBinary(t, binary) + if exitCode != 1 { + t.Fatalf("expected exit code 1, got %d", exitCode) + } + if !strings.Contains(stderr, "--pdns_apikey is required") { + t.Fatalf("expected error about --pdns_apikey, got:\n%s", stderr) + } +} + +func TestIntegration_Main_MissingURL(t *testing.T) { + binary := buildBinary(t) + stderr, exitCode := runBinary(t, binary, "--pdns_apikey=testkey") + if exitCode != 1 { + t.Fatalf("expected exit code 1, got %d", exitCode) + } + if !strings.Contains(stderr, "--pdns_url is required") { + t.Fatalf("expected error about --pdns_url, got:\n%s", stderr) + } +} + +func TestIntegration_Main_MissingRRSet(t *testing.T) { + binary := buildBinary(t) + stderr, exitCode := runBinary(t, binary, "--pdns_apikey=testkey", "--pdns_url=http://localhost:8081") + if exitCode != 1 { + t.Fatalf("expected exit code 1, got %d", exitCode) + } + if !strings.Contains(stderr, "at least one --rrset is required") { + t.Fatalf("expected error about --rrset, got:\n%s", stderr) + } +} + +func TestIntegration_Main_NegativeInterval(t *testing.T) { + binary := buildBinary(t) + stderr, exitCode := runBinary(t, binary, + "--pdns_apikey=testkey", + "--pdns_url=http://localhost:8081", + "-r", "myhost.example.com.,example.com.", + "-b", + "-i=-1s", + ) + if exitCode != 1 { + t.Fatalf("expected exit code 1, got %d", exitCode) + } + if !strings.Contains(stderr, "--interval must be positive") { + t.Fatalf("expected error about --interval, got:\n%s", stderr) + } +} + +func TestIntegration_Main_InvalidDiscordWebhook(t *testing.T) { + binary := buildBinary(t) + stderr, exitCode := runBinary(t, binary, + "--pdns_apikey=testkey", + "--pdns_url=http://localhost:8081", + "-r", "myhost.example.com.,example.com.", + "--discord=ftp://bad", + ) + if exitCode != 1 { + t.Fatalf("expected exit code 1, got %d", exitCode) + } + if !strings.Contains(stderr, "invalid --discord") { + t.Fatalf("expected error about --discord, got:\n%s", stderr) + } +} + +func TestIntegration_Main_InvalidNtfyWebhook(t *testing.T) { + binary := buildBinary(t) + stderr, exitCode := runBinary(t, binary, + "--pdns_apikey=testkey", + "--pdns_url=http://localhost:8081", + "-r", "myhost.example.com.,example.com.", + "--ntfy=ftp://bad", + ) + if exitCode != 1 { + t.Fatalf("expected exit code 1, got %d", exitCode) + } + if !strings.Contains(stderr, "invalid --ntfy") { + t.Fatalf("expected error about --ntfy, got:\n%s", stderr) + } +} + +func TestIntegration_Client_NewRecordsClient(t *testing.T) { + pdnsURL := envOrDefault("PDNS_URL", defaultPDNSURL) + pdnsKey := envOrDefault("PDNS_API_KEY", defaultPDNSKey) + waitForPowerDNS(t, pdnsKey, pdnsURL) + + httpClient := newTestHTTPClient() + client := newRecordsClient(pdnsKey, pdnsURL, httpClient) + if client == nil { + t.Fatal("expected non-nil RecordsClient") + } +} + +func TestIntegration_Client_Get(t *testing.T) { + recordsClient, _, _, cleanup := setupIntegration(t) + defer cleanup() + + pdnsURL := envOrDefault("PDNS_URL", defaultPDNSURL) + pdnsKey := envOrDefault("PDNS_API_KEY", defaultPDNSKey) + waitForPowerDNS(t, pdnsKey, pdnsURL) + + httpClient := newTestHTTPClient() + concreteClient := newRecordsClient(pdnsKey, pdnsURL, httpClient) + + rrsets, err := concreteClient.Get(context.Background(), testZone, testRecord, powerdns.RRTypePtr(powerdns.RRTypeA)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(rrsets) != 1 { + t.Fatalf("expected 1 rrset, got %d", len(rrsets)) + } + if len(rrsets[0].Records) != 1 || rrsets[0].Records[0].Content == nil { + t.Fatal("expected 1 record with non-nil content") + } + if *rrsets[0].Records[0].Content != testInitialIP { + t.Errorf("expected IP %s, got %s", testInitialIP, *rrsets[0].Records[0].Content) + } + + _ = recordsClient +} + +func TestIntegration_Client_Change(t *testing.T) { + pdnsURL := envOrDefault("PDNS_URL", defaultPDNSURL) + pdnsKey := envOrDefault("PDNS_API_KEY", defaultPDNSKey) + waitForPowerDNS(t, pdnsKey, pdnsURL) + + zone := "clienttest.example.net." + seedZone(t, pdnsKey, pdnsURL, zone) + defer deleteZone(t, pdnsKey, pdnsURL, zone) + + httpClient := newTestHTTPClient() + concreteClient := newRecordsClient(pdnsKey, pdnsURL, httpClient) + record := "test." + zone + + err := concreteClient.Change(context.Background(), zone, record, powerdns.RRTypeA, 60, []string{testInitialIP}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + rrsets, err := concreteClient.Get(context.Background(), zone, record, powerdns.RRTypePtr(powerdns.RRTypeA)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(rrsets) == 0 || len(rrsets[0].Records) == 0 || rrsets[0].Records[0].Content == nil { + t.Fatal("expected record to exist after Change") + } + if *rrsets[0].Records[0].Content != testInitialIP { + t.Errorf("expected IP %s, got %s", testInitialIP, *rrsets[0].Records[0].Content) + } +} diff --git a/ip.go b/ip.go index fbbbc95..ea36c67 100644 --- a/ip.go +++ b/ip.go @@ -11,8 +11,12 @@ import ( const maxIPResponseSize = 45 -func getLiveIP(ctx context.Context, httpClient *http.Client) (string, error) { - req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://ifconfig.me", nil) +const defaultIPCheckURL = "https://ifconfig.me" + +// getLiveIP fetches the public IPv4 address from the provided URL, validates that the response is a syntactically valid IPv4 address, and returns it. +// On failure it returns an empty string and an *IPCheckError describing the cause. +func getLiveIP(ctx context.Context, httpClient *http.Client, url string) (string, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { return "", &IPCheckError{Err: fmt.Errorf("create request: %w", err)} } diff --git a/ip_test.go b/ip_test.go new file mode 100644 index 0000000..0048871 --- /dev/null +++ b/ip_test.go @@ -0,0 +1,196 @@ +package main + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "net/http" + "strings" + "testing" +) + +type mockTransport struct { + statusCode int + body string + bodyReader io.ReadCloser + err error +} + +func (t *mockTransport) RoundTrip(req *http.Request) (*http.Response, error) { + if t.err != nil { + return nil, t.err + } + body := t.bodyReader + if body == nil { + body = io.NopCloser(strings.NewReader(t.body)) + } + return &http.Response{ + StatusCode: t.statusCode, + Status: http.StatusText(t.statusCode), + Body: body, + Header: make(http.Header), + }, nil +} + +func TestGetLiveIP_Success(t *testing.T) { + client := &http.Client{Transport: &mockTransport{statusCode: 200, body: "1.2.3.4"}} + ip, err := getLiveIP(context.Background(), client, defaultIPCheckURL) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if ip != "1.2.3.4" { + t.Errorf("expected 1.2.3.4, got %s", ip) + } +} + +func TestGetLiveIP_WhitespaceTrimmed(t *testing.T) { + client := &http.Client{Transport: &mockTransport{statusCode: 200, body: " 10.0.0.1\n "}} + ip, err := getLiveIP(context.Background(), client, defaultIPCheckURL) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if ip != "10.0.0.1" { + t.Errorf("expected 10.0.0.1, got %s", ip) + } +} + +func TestGetLiveIP_Non200Status(t *testing.T) { + client := &http.Client{Transport: &mockTransport{statusCode: 500, body: "error"}} + _, err := getLiveIP(context.Background(), client, defaultIPCheckURL) + if err == nil { + t.Fatal("expected error for non-200 status") + } + var ipErr *IPCheckError + if !errors.As(err, &ipErr) { + t.Fatalf("expected IPCheckError, got %T", err) + } +} + +func TestGetLiveIP_InvalidIPv4(t *testing.T) { + client := &http.Client{Transport: &mockTransport{statusCode: 200, body: "not-an-ip"}} + _, err := getLiveIP(context.Background(), client, defaultIPCheckURL) + if err == nil { + t.Fatal("expected error for invalid IP") + } + var ipErr *IPCheckError + if !errors.As(err, &ipErr) { + t.Fatalf("expected IPCheckError, got %T", err) + } +} + +func TestGetLiveIP_IPv6Rejected(t *testing.T) { + client := &http.Client{Transport: &mockTransport{statusCode: 200, body: "::1"}} + _, err := getLiveIP(context.Background(), client, defaultIPCheckURL) + if err == nil { + t.Fatal("expected error for IPv6 address") + } +} + +func TestGetLiveIP_RequestError(t *testing.T) { + client := &http.Client{Transport: &mockTransport{err: &net.DNSError{Err: "dns failure"}}} + _, err := getLiveIP(context.Background(), client, defaultIPCheckURL) + if err == nil { + t.Fatal("expected error for request failure") + } + var ipErr *IPCheckError + if !errors.As(err, &ipErr) { + t.Fatalf("expected IPCheckError, got %T", err) + } +} + +func TestGetLiveIP_InvalidURL(t *testing.T) { + client := &http.Client{Transport: &mockTransport{statusCode: 200, body: "1.2.3.4"}} + _, err := getLiveIP(context.Background(), client, "://bad") + if err == nil { + t.Fatal("expected error for invalid URL") + } + var ipErr *IPCheckError + if !errors.As(err, &ipErr) { + t.Fatalf("expected IPCheckError, got %T", err) + } + if ipErr.Err == nil { + t.Fatal("expected IPCheckError.Err to be non-nil") + } +} + +func TestGetLiveIP_UserAgentSet(t *testing.T) { + var capturedUA string + transport := &roundTripCapturer{ + statusCode: 200, + body: "1.2.3.4", + capture: func(req *http.Request) { + capturedUA = req.Header.Get("User-Agent") + }, + } + client := &http.Client{Transport: transport} + _, err := getLiveIP(context.Background(), client, defaultIPCheckURL) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if capturedUA != "curl/8.12.1" { + t.Errorf("expected User-Agent 'curl/8.12.1', got %q", capturedUA) + } +} + +func TestGetLiveIP_MethodIsGET(t *testing.T) { + var capturedMethod string + transport := &roundTripCapturer{ + statusCode: 200, + body: "1.2.3.4", + capture: func(req *http.Request) { + capturedMethod = req.Method + }, + } + client := &http.Client{Transport: transport} + _, err := getLiveIP(context.Background(), client, defaultIPCheckURL) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if capturedMethod != http.MethodGet { + t.Errorf("expected GET method, got %q", capturedMethod) + } +} + +type readErrReader struct{} + +func (r *readErrReader) Read(p []byte) (int, error) { + return 0, fmt.Errorf("read error") +} + +func (r *readErrReader) Close() error { + return nil +} + +func TestGetLiveIP_ReadError(t *testing.T) { + client := &http.Client{Transport: &mockTransport{ + statusCode: 200, + body: "", + bodyReader: io.NopCloser(&readErrReader{}), + }} + _, err := getLiveIP(context.Background(), client, defaultIPCheckURL) + if err == nil { + t.Fatal("expected error for read failure") + } + var ipErr *IPCheckError + if !errors.As(err, &ipErr) { + t.Fatalf("expected IPCheckError, got %T", err) + } +} + +type roundTripCapturer struct { + statusCode int + body string + capture func(*http.Request) +} + +func (t *roundTripCapturer) RoundTrip(req *http.Request) (*http.Response, error) { + t.capture(req) + return &http.Response{ + StatusCode: t.statusCode, + Status: http.StatusText(t.statusCode), + Body: io.NopCloser(strings.NewReader(t.body)), + Header: make(http.Header), + }, nil +} diff --git a/main.go b/main.go index c47471d..891980c 100644 --- a/main.go +++ b/main.go @@ -13,6 +13,13 @@ import ( "github.com/spf13/pflag" ) +// main parses command-line flags, validates required configuration, prepares networking and signal handling, +// and then executes the record-sync cycle once or repeatedly in daemon mode. +// +// It requires PowerDNS API key and URL and at least one --rrset value, validates optional webhook URLs, +// and enforces a positive interval when running as a background daemon. It configures an HTTP client, +// establishes cancellation on SIGINT/SIGTERM, and calls runOnce immediately; when in daemon mode it +// repeats runOnce on each tick until shutdown. func main() { pdnsApiKey := pflag.String("pdns_apikey", "", "Set the PowerDNS API Key") pdnsApiUrl := pflag.String("pdns_url", "", "Set the PowerDNS API URL") @@ -84,8 +91,13 @@ func main() { } } +// runOnce performs a single external-IP check and executes one reconciliation cycle +// for the provided record sets using the PowerDNS API. If the live IP cannot be +// obtained it logs a warning and returns; otherwise it creates a records client +// and runs a processing cycle that reconciles DNS records with the retrieved IP +// and sends notifications to the configured webhooks. func runOnce(ctx context.Context, httpClient *http.Client, pdnsApiKey, pdnsApiUrl, discord, ntfy string, recordSets rrsetSlice) { - liveIP, err := getLiveIP(ctx, httpClient) + liveIP, err := getLiveIP(ctx, httpClient, defaultIPCheckURL) if err != nil { slog.Warn("failed to get live IP", "error", err) return diff --git a/main_test.go b/main_test.go index e3b3da5..9395b55 100644 --- a/main_test.go +++ b/main_test.go @@ -2,9 +2,6 @@ package main import ( "context" - "errors" - "fmt" - "testing" "github.com/joeig/go-powerdns/v3" ) @@ -48,283 +45,3 @@ func (m *MockRecordsClient) Change(ctx context.Context, domain string, name stri func strPtr(s string) *string { return &s } - -func TestUpdateRecord_IPChanged(t *testing.T) { - mock := &MockRecordsClient{ - GetResult: []powerdns.RRset{ - { - Records: []powerdns.Record{ - {Content: strPtr("1.2.3.4")}, - }, - }, - }, - } - - changed, err := updateRecord(context.Background(), RRSet{Name: "myhost", Zone: "example.com"}, "5.6.7.8", mock) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if !changed { - t.Fatal("expected changed to be true") - } - if !mock.GetCalled { - t.Fatal("expected Get to be called") - } - if !mock.ChangeCalled { - t.Fatal("expected Change to be called") - } - if len(mock.ChangeCalls) != 1 { - t.Fatalf("expected 1 Change call, got %d", len(mock.ChangeCalls)) - } - call := mock.ChangeCalls[0] - if call.domain != "example.com" { - t.Errorf("expected domain example.com, got %s", call.domain) - } - if call.name != "myhost" { - t.Errorf("expected name myhost, got %s", call.name) - } - if call.ttl != 60 { - t.Errorf("expected ttl 60, got %d", call.ttl) - } - if len(call.content) != 1 || call.content[0] != "5.6.7.8" { - t.Errorf("expected content [5.6.7.8], got %v", call.content) - } -} - -func TestUpdateRecord_IPUnchanged(t *testing.T) { - mock := &MockRecordsClient{ - GetResult: []powerdns.RRset{ - { - Records: []powerdns.Record{ - {Content: strPtr("1.2.3.4")}, - }, - }, - }, - } - - changed, err := updateRecord(context.Background(), RRSet{Name: "myhost", Zone: "example.com"}, "1.2.3.4", mock) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if changed { - t.Fatal("expected changed to be false since IP unchanged") - } - if !mock.GetCalled { - t.Fatal("expected Get to be called") - } - if mock.ChangeCalled { - t.Fatal("expected Change NOT to be called when IP is unchanged") - } -} - -func TestUpdateRecord_GetError(t *testing.T) { - mock := &MockRecordsClient{ - GetError: fmt.Errorf("API error"), - } - - changed, err := updateRecord(context.Background(), RRSet{Name: "myhost", Zone: "example.com"}, "5.6.7.8", mock) - if err == nil { - t.Fatal("expected error on Get failure") - } - if changed { - t.Fatal("expected changed to be false on Get failure") - } - if mock.ChangeCalled { - t.Fatal("expected Change NOT to be called on Get failure") - } - var rerr *RecordUpdateError - if !errors.As(err, &rerr) { - t.Fatalf("expected RecordUpdateError, got %T", err) - } - if rerr.Record != "myhost" || rerr.Zone != "example.com" { - t.Errorf("expected record=myhost zone=example.com, got record=%s zone=%s", rerr.Record, rerr.Zone) - } -} - -func TestUpdateRecord_MultipleRRSets(t *testing.T) { - mock := &MockRecordsClient{ - GetResult: []powerdns.RRset{ - {Records: []powerdns.Record{{Content: strPtr("1.2.3.4")}}}, - {Records: []powerdns.Record{{Content: strPtr("5.6.7.8")}}}, - }, - } - - changed, err := updateRecord(context.Background(), RRSet{Name: "myhost", Zone: "example.com"}, "9.9.9.9", mock) - if err == nil { - t.Fatal("expected error when multiple rrsets returned") - } - if changed { - t.Fatal("expected changed to be false") - } -} - -func TestUpdateRecord_MultipleRecordsInRRSet(t *testing.T) { - mock := &MockRecordsClient{ - GetResult: []powerdns.RRset{ - { - Records: []powerdns.Record{ - {Content: strPtr("1.2.3.4")}, - {Content: strPtr("5.6.7.8")}, - }, - }, - }, - } - - changed, err := updateRecord(context.Background(), RRSet{Name: "myhost", Zone: "example.com"}, "9.9.9.9", mock) - if err == nil { - t.Fatal("expected error when multiple records in rrset") - } - if changed { - t.Fatal("expected changed to be false") - } -} - -func TestUpdateRecord_ChangeError(t *testing.T) { - mock := &MockRecordsClient{ - GetResult: []powerdns.RRset{ - { - Records: []powerdns.Record{ - {Content: strPtr("1.2.3.4")}, - }, - }, - }, - ChangeError: fmt.Errorf("change failed"), - } - - changed, err := updateRecord(context.Background(), RRSet{Name: "myhost", Zone: "example.com"}, "5.6.7.8", mock) - if err == nil { - t.Fatal("expected error when Change fails") - } - if changed { - t.Fatal("expected changed to be false on Change failure") - } -} - -func TestUpdateRecord_EmptyRRSets(t *testing.T) { - mock := &MockRecordsClient{ - GetResult: []powerdns.RRset{}, - } - - changed, err := updateRecord(context.Background(), RRSet{Name: "myhost", Zone: "example.com"}, "5.6.7.8", mock) - if err == nil { - t.Fatal("expected error when no rrsets returned") - } - if changed { - t.Fatal("expected changed to be false") - } -} - -func TestUpdateRecord_NilContent(t *testing.T) { - mock := &MockRecordsClient{ - GetResult: []powerdns.RRset{ - { - Records: []powerdns.Record{ - {Content: nil}, - }, - }, - }, - } - - changed, err := updateRecord(context.Background(), RRSet{Name: "myhost", Zone: "example.com"}, "5.6.7.8", mock) - if err == nil { - t.Fatal("expected error when Content is nil") - } - if changed { - t.Fatal("expected changed to be false") - } - if mock.ChangeCalled { - t.Fatal("expected Change NOT to be called when Content is nil") - } -} - -func TestIPCheckError_Unwrap(t *testing.T) { - inner := fmt.Errorf("inner error") - err := &IPCheckError{Err: inner} - if !errors.Is(err, inner) { - t.Fatal("expected errors.Is to match inner error") - } -} - -func TestRecordUpdateError_Unwrap(t *testing.T) { - inner := fmt.Errorf("inner error") - err := &RecordUpdateError{Record: "myhost", Zone: "example.com", Err: inner} - if !errors.Is(err, inner) { - t.Fatal("expected errors.Is to match inner error") - } -} - -func TestNotificationError_Unwrap(t *testing.T) { - inner := fmt.Errorf("inner error") - err := &NotificationError{Service: "discord", Err: inner} - if !errors.Is(err, inner) { - t.Fatal("expected errors.Is to match inner error") - } -} - -func TestRRSetSlice_ValidInput(t *testing.T) { - var r rrsetSlice - if err := r.Set("myhost,example.com"); err != nil { - t.Fatalf("unexpected error: %v", err) - } - if err := r.Set(" otherhost , otherzone.io "); err != nil { - t.Fatalf("unexpected error: %v", err) - } - if len(r) != 2 { - t.Fatalf("expected 2 rrsets, got %d", len(r)) - } - if r[0].Name != "myhost" || r[0].Zone != "example.com" { - t.Errorf("expected myhost/example.com, got %s/%s", r[0].Name, r[0].Zone) - } - if r[1].Name != "otherhost" || r[1].Zone != "otherzone.io" { - t.Errorf("expected otherhost/otherzone.io, got %s/%s", r[1].Name, r[1].Zone) - } -} - -func TestRRSetSlice_NoComma(t *testing.T) { - var r rrsetSlice - err := r.Set("myhost") - if err == nil { - t.Fatal("expected error for input without comma") - } -} - -func TestRRSetSlice_EmptyName(t *testing.T) { - var r rrsetSlice - err := r.Set(",example.com") - if err == nil { - t.Fatal("expected error for empty record name") - } -} - -func TestRRSetSlice_EmptyZone(t *testing.T) { - var r rrsetSlice - err := r.Set("myhost,") - if err == nil { - t.Fatal("expected error for empty zone") - } -} - -func TestRRSetSlice_CommaInZone(t *testing.T) { - var r rrsetSlice - if err := r.Set("myhost,example.com,extra"); err != nil { - t.Fatalf("unexpected error: %v", err) - } - if r[0].Zone != "example.com,extra" { - t.Errorf("expected zone 'example.com,extra', got %s", r[0].Zone) - } -} - -func TestRRSetSlice_Type(t *testing.T) { - var r rrsetSlice - if r.Type() != "record,zone" { - t.Errorf("expected type 'record,zone', got %s", r.Type()) - } -} - -func TestRRSetSlice_String(t *testing.T) { - r := rrsetSlice{{Name: "a", Zone: "b.com"}, {Name: "c", Zone: "d.org"}} - expected := "a,b.com; c,d.org" - if r.String() != expected { - t.Errorf("expected %q, got %q", expected, r.String()) - } -} diff --git a/notify_test.go b/notify_test.go new file mode 100644 index 0000000..e979875 --- /dev/null +++ b/notify_test.go @@ -0,0 +1,278 @@ +package main + +import ( + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestNotifyDiscord_Success(t *testing.T) { + var receivedBody map[string]string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("expected POST, got %s", r.Method) + } + if r.Header.Get("Content-Type") != "application/json" { + t.Errorf("expected application/json content-type, got %s", r.Header.Get("Content-Type")) + } + body, _ := io.ReadAll(r.Body) + json.Unmarshal(body, &receivedBody) + w.WriteHeader(http.StatusNoContent) + })) + defer server.Close() + + err := notifyDiscord("test message", server.URL+"/webhook", server.Client()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if receivedBody["content"] != "test message" { + t.Errorf("expected content 'test message', got %q", receivedBody["content"]) + } +} + +func TestNotifyDiscord_HTTPError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + })) + defer server.Close() + + client := server.Client() + url := server.URL + server.Close() + + err := notifyDiscord("test", url+"/webhook", client) + if err == nil { + t.Fatal("expected error when server is unreachable") + } +} + +func TestNotifyDiscord_BadStatus(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusForbidden) + })) + defer server.Close() + + err := notifyDiscord("test", server.URL+"/webhook", server.Client()) + if err == nil { + t.Fatal("expected error for bad status code") + } + var nerr *NotificationError + if !errors.As(err, &nerr) { + t.Fatalf("expected NotificationError, got %T", err) + } + if nerr.Service != "discord" { + t.Errorf("expected service discord, got %s", nerr.Service) + } +} + +func TestNotifyNtfy_Success(t *testing.T) { + var receivedBody string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("expected POST, got %s", r.Method) + } + if r.Header.Get("Content-Type") != "text/plain" { + t.Errorf("expected text/plain content-type, got %s", r.Header.Get("Content-Type")) + } + body, _ := io.ReadAll(r.Body) + receivedBody = string(body) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + err := notifyNtfy("test message", server.URL+"/mytopic", server.Client()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if receivedBody != "test message" { + t.Errorf("expected body 'test message', got %q", receivedBody) + } +} + +func TestNotifyNtfy_HTTPError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + defer server.Close() + + client := server.Client() + url := server.URL + server.Close() + + err := notifyNtfy("test", url+"/webhook", client) + if err == nil { + t.Fatal("expected error when server is unreachable") + } +} + +func TestNotifyNtfy_BadStatus(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer server.Close() + + err := notifyNtfy("test", server.URL+"/webhook", server.Client()) + if err == nil { + t.Fatal("expected error for bad status code") + } + var nerr *NotificationError + if !errors.As(err, &nerr) { + t.Fatalf("expected NotificationError, got %T", err) + } + if nerr.Service != "ntfy" { + t.Errorf("expected service ntfy, got %s", nerr.Service) + } +} + +func TestSendNotifications_NoWebhooks(t *testing.T) { + err := sendNotifications("test", "", "", http.DefaultClient) + if err != nil { + t.Fatalf("expected nil when no webhooks configured, got %v", err) + } +} + +func TestSendNotifications_DiscordOnly(t *testing.T) { + discordCalled := false + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + discordCalled = true + w.WriteHeader(http.StatusNoContent) + })) + defer server.Close() + + err := sendNotifications("test", server.URL+"/webhook", "", server.Client()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !discordCalled { + t.Fatal("expected discord to be called") + } +} + +func TestSendNotifications_NtfyOnly(t *testing.T) { + ntfyCalled := false + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ntfyCalled = true + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + err := sendNotifications("test", "", server.URL+"/mytopic", server.Client()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !ntfyCalled { + t.Fatal("expected ntfy to be called") + } +} + +func TestSendNotifications_Both(t *testing.T) { + var calls []string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + calls = append(calls, r.URL.Path) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + err := sendNotifications("test", server.URL+"/discord", server.URL+"/ntfy", server.Client()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(calls) != 2 { + t.Fatalf("expected 2 calls, got %d", len(calls)) + } +} + +func TestSendNotifications_DiscordFails(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer server.Close() + + err := sendNotifications("test", server.URL+"/webhook", "", server.Client()) + if err == nil { + t.Fatal("expected error when discord fails") + } + var nerr *NotificationError + if !errors.As(err, &nerr) { + t.Fatalf("expected NotificationError, got %T", err) + } +} + +func TestSendNotifications_NtfyFails(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer server.Close() + + err := sendNotifications("test", "", server.URL+"/topic", server.Client()) + if err == nil { + t.Fatal("expected error when ntfy fails") + } + var nerr *NotificationError + if !errors.As(err, &nerr) { + t.Fatalf("expected NotificationError, got %T", err) + } +} + +func TestSendNotifications_BothFail(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer server.Close() + + err := sendNotifications("test", server.URL+"/discord", server.URL+"/ntfy", server.Client()) + if err == nil { + t.Fatal("expected error when both fail") + } + + var discordErr, ntfyErr *NotificationError + if !errors.As(err, &discordErr) { + t.Fatalf("expected discord NotificationError in joined error, got %T", err) + } + if !errors.As(err, &ntfyErr) { + t.Fatalf("expected ntfy NotificationError in joined error, got %T", err) + } +} + +func TestNotifyDiscord_PayloadFormat(t *testing.T) { + var contentType string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + contentType = r.Header.Get("Content-Type") + io.ReadAll(r.Body) + w.WriteHeader(http.StatusNoContent) + })) + defer server.Close() + + err := notifyDiscord("hello world", server.URL+"/webhook", server.Client()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !strings.Contains(contentType, "application/json") { + t.Errorf("expected application/json content-type, got %s", contentType) + } +} + +func TestNotifyNtfy_PayloadFormat(t *testing.T) { + var contentType string + var bodyStr string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + contentType = r.Header.Get("Content-Type") + body, _ := io.ReadAll(r.Body) + bodyStr = string(body) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + err := notifyNtfy("hello world", server.URL+"/topic", server.Client()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !strings.Contains(contentType, "text/plain") { + t.Errorf("expected text/plain content-type, got %s", contentType) + } + if bodyStr != "hello world" { + t.Errorf("expected body 'hello world', got %q", bodyStr) + } +} diff --git a/rrset_test.go b/rrset_test.go new file mode 100644 index 0000000..108e1fa --- /dev/null +++ b/rrset_test.go @@ -0,0 +1,73 @@ +package main + +import ( + "testing" +) + +func TestRRSetSlice_ValidInput(t *testing.T) { + var r rrsetSlice + if err := r.Set("myhost,example.com"); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if err := r.Set(" otherhost , otherzone.io "); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(r) != 2 { + t.Fatalf("expected 2 rrsets, got %d", len(r)) + } + if r[0].Name != "myhost" || r[0].Zone != "example.com" { + t.Errorf("expected myhost/example.com, got %s/%s", r[0].Name, r[0].Zone) + } + if r[1].Name != "otherhost" || r[1].Zone != "otherzone.io" { + t.Errorf("expected otherhost/otherzone.io, got %s/%s", r[1].Name, r[1].Zone) + } +} + +func TestRRSetSlice_NoComma(t *testing.T) { + var r rrsetSlice + err := r.Set("myhost") + if err == nil { + t.Fatal("expected error for input without comma") + } +} + +func TestRRSetSlice_EmptyName(t *testing.T) { + var r rrsetSlice + err := r.Set(",example.com") + if err == nil { + t.Fatal("expected error for empty record name") + } +} + +func TestRRSetSlice_EmptyZone(t *testing.T) { + var r rrsetSlice + err := r.Set("myhost,") + if err == nil { + t.Fatal("expected error for empty zone") + } +} + +func TestRRSetSlice_CommaInZone(t *testing.T) { + var r rrsetSlice + if err := r.Set("myhost,example.com,extra"); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if r[0].Zone != "example.com,extra" { + t.Errorf("expected zone 'example.com,extra', got %s", r[0].Zone) + } +} + +func TestRRSetSlice_Type(t *testing.T) { + var r rrsetSlice + if r.Type() != "record,zone" { + t.Errorf("expected type 'record,zone', got %s", r.Type()) + } +} + +func TestRRSetSlice_String(t *testing.T) { + r := rrsetSlice{{Name: "a", Zone: "b.com"}, {Name: "c", Zone: "d.org"}} + expected := "a,b.com; c,d.org" + if r.String() != expected { + t.Errorf("expected %q, got %q", expected, r.String()) + } +} diff --git a/update_test.go b/update_test.go new file mode 100644 index 0000000..d141331 --- /dev/null +++ b/update_test.go @@ -0,0 +1,198 @@ +package main + +import ( + "context" + "errors" + "fmt" + "testing" + + "github.com/joeig/go-powerdns/v3" +) + +func TestUpdateRecord_IPChanged(t *testing.T) { + mock := &MockRecordsClient{ + GetResult: []powerdns.RRset{ + { + Records: []powerdns.Record{ + {Content: strPtr("1.2.3.4")}, + }, + }, + }, + } + + changed, err := updateRecord(context.Background(), RRSet{Name: "myhost", Zone: "example.com"}, "5.6.7.8", mock) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !changed { + t.Fatal("expected changed to be true") + } + if !mock.GetCalled { + t.Fatal("expected Get to be called") + } + if !mock.ChangeCalled { + t.Fatal("expected Change to be called") + } + if len(mock.ChangeCalls) != 1 { + t.Fatalf("expected 1 Change call, got %d", len(mock.ChangeCalls)) + } + call := mock.ChangeCalls[0] + if call.domain != "example.com" { + t.Errorf("expected domain example.com, got %s", call.domain) + } + if call.name != "myhost" { + t.Errorf("expected name myhost, got %s", call.name) + } + if call.ttl != 60 { + t.Errorf("expected ttl 60, got %d", call.ttl) + } + if len(call.content) != 1 || call.content[0] != "5.6.7.8" { + t.Errorf("expected content [5.6.7.8], got %v", call.content) + } +} + +func TestUpdateRecord_IPUnchanged(t *testing.T) { + mock := &MockRecordsClient{ + GetResult: []powerdns.RRset{ + { + Records: []powerdns.Record{ + {Content: strPtr("1.2.3.4")}, + }, + }, + }, + } + + changed, err := updateRecord(context.Background(), RRSet{Name: "myhost", Zone: "example.com"}, "1.2.3.4", mock) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if changed { + t.Fatal("expected changed to be false since IP unchanged") + } + if !mock.GetCalled { + t.Fatal("expected Get to be called") + } + if mock.ChangeCalled { + t.Fatal("expected Change NOT to be called when IP is unchanged") + } +} + +func TestUpdateRecord_GetError(t *testing.T) { + mock := &MockRecordsClient{ + GetError: fmt.Errorf("API error"), + } + + changed, err := updateRecord(context.Background(), RRSet{Name: "myhost", Zone: "example.com"}, "5.6.7.8", mock) + if err == nil { + t.Fatal("expected error on Get failure") + } + if changed { + t.Fatal("expected changed to be false on Get failure") + } + if mock.ChangeCalled { + t.Fatal("expected Change NOT to be called on Get failure") + } + var rerr *RecordUpdateError + if !errors.As(err, &rerr) { + t.Fatalf("expected RecordUpdateError, got %T", err) + } + if rerr.Record != "myhost" || rerr.Zone != "example.com" { + t.Errorf("expected record=myhost zone=example.com, got record=%s zone=%s", rerr.Record, rerr.Zone) + } +} + +func TestUpdateRecord_MultipleRRSets(t *testing.T) { + mock := &MockRecordsClient{ + GetResult: []powerdns.RRset{ + {Records: []powerdns.Record{{Content: strPtr("1.2.3.4")}}}, + {Records: []powerdns.Record{{Content: strPtr("5.6.7.8")}}}, + }, + } + + changed, err := updateRecord(context.Background(), RRSet{Name: "myhost", Zone: "example.com"}, "9.9.9.9", mock) + if err == nil { + t.Fatal("expected error when multiple rrsets returned") + } + if changed { + t.Fatal("expected changed to be false") + } +} + +func TestUpdateRecord_MultipleRecordsInRRSet(t *testing.T) { + mock := &MockRecordsClient{ + GetResult: []powerdns.RRset{ + { + Records: []powerdns.Record{ + {Content: strPtr("1.2.3.4")}, + {Content: strPtr("5.6.7.8")}, + }, + }, + }, + } + + changed, err := updateRecord(context.Background(), RRSet{Name: "myhost", Zone: "example.com"}, "9.9.9.9", mock) + if err == nil { + t.Fatal("expected error when multiple records in rrset") + } + if changed { + t.Fatal("expected changed to be false") + } +} + +func TestUpdateRecord_ChangeError(t *testing.T) { + mock := &MockRecordsClient{ + GetResult: []powerdns.RRset{ + { + Records: []powerdns.Record{ + {Content: strPtr("1.2.3.4")}, + }, + }, + }, + ChangeError: fmt.Errorf("change failed"), + } + + changed, err := updateRecord(context.Background(), RRSet{Name: "myhost", Zone: "example.com"}, "5.6.7.8", mock) + if err == nil { + t.Fatal("expected error when Change fails") + } + if changed { + t.Fatal("expected changed to be false on Change failure") + } +} + +func TestUpdateRecord_EmptyRRSets(t *testing.T) { + mock := &MockRecordsClient{ + GetResult: []powerdns.RRset{}, + } + + changed, err := updateRecord(context.Background(), RRSet{Name: "myhost", Zone: "example.com"}, "5.6.7.8", mock) + if err == nil { + t.Fatal("expected error when no rrsets returned") + } + if changed { + t.Fatal("expected changed to be false") + } +} + +func TestUpdateRecord_NilContent(t *testing.T) { + mock := &MockRecordsClient{ + GetResult: []powerdns.RRset{ + { + Records: []powerdns.Record{ + {Content: nil}, + }, + }, + }, + } + + changed, err := updateRecord(context.Background(), RRSet{Name: "myhost", Zone: "example.com"}, "5.6.7.8", mock) + if err == nil { + t.Fatal("expected error when Content is nil") + } + if changed { + t.Fatal("expected changed to be false") + } + if mock.ChangeCalled { + t.Fatal("expected Change NOT to be called when Content is nil") + } +}