package sqs

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"log"
	"time"

	"github.com/aws/aws-sdk-go-v2/aws"
	"github.com/aws/aws-sdk-go-v2/config"
	"github.com/aws/aws-sdk-go-v2/service/sqs"
	"github.com/aws/aws-sdk-go-v2/service/sqs/types"
)

// Message is the parsed, internal form of a message received from the SQS queue.
type Message struct {
	// ID is the SQS message ID.
	ID string
	// Body is the message body exactly as received.
	Body string
	// ReceiptHandle is the handle passed to DeleteMessage and
	// ChangeMessageVisibility to act on this received message.
	ReceiptHandle string
	// RawEventID is the "raw_event_id" field decoded from the JSON body.
	RawEventID string
	// CorrelationID is the optional correlation ID from message attributes:
	// "correlation_id" is preferred, "x-correlation-id" is the fallback, and
	// it is nil when neither attribute carries a string value.
	CorrelationID *string
}

// RawEventMessage is the expected JSON structure of an SQS message body.
// Only RawEventID is required; parseMessage rejects bodies where it is empty.
type RawEventMessage struct {
	RawEventID string `json:"raw_event_id"`
	Timestamp  string `json:"timestamp,omitempty"`
	Source     string `json:"source,omitempty"`
}

// Client wraps the AWS SQS client together with the queue URL and the
// receive settings applied to every ReceiveMessage call.
//
// NOTE(review): Client is described as implementing an SQSClient interface
// that is not defined in this file — confirm against that definition.
type Client struct {
	sqsClient         *sqs.Client
	queueURL          string
	maxMessages       int32
	waitTimeSeconds   int32
	visibilityTimeout int32
	region            string
}

// NewClient loads the default AWS configuration for region and returns a
// Client bound to queueURL.
//
// maxMessages, waitTimeSeconds and visibilityTimeout are stored as-is and
// sent on every ReceiveMessages call; no validation is done here. An error is
// returned only when the AWS configuration cannot be loaded.
func NewClient(region, queueURL string, maxMessages, waitTimeSeconds, visibilityTimeout int32) (*Client, error) {
	// NewClient takes no ctx, so configuration loading cannot be cancelled.
	cfg, err := config.LoadDefaultConfig(context.TODO(),
		config.WithRegion(region),
	)
	if err != nil {
		return nil, fmt.Errorf("failed to load AWS config: %w", err)
	}

	return &Client{
		sqsClient:         sqs.NewFromConfig(cfg),
		queueURL:          queueURL,
		maxMessages:       maxMessages,
		waitTimeSeconds:   waitTimeSeconds,
		visibilityTimeout: visibilityTimeout,
		region:            region,
	}, nil
}

// ReceiveMessages performs one ReceiveMessage call against the queue,
// requesting all message attributes, and returns the messages that parse
// successfully.
//
// Messages that fail to parse are logged and omitted from the result; a
// failure of the ReceiveMessage call itself is returned as a wrapped error.
func (c *Client) ReceiveMessages(ctx context.Context) ([]*Message, error) {
	input := &sqs.ReceiveMessageInput{
		QueueUrl:              aws.String(c.queueURL),
		MaxNumberOfMessages:   c.maxMessages,
		WaitTimeSeconds:       c.waitTimeSeconds,
		VisibilityTimeout:     c.visibilityTimeout,
		MessageAttributeNames: []string{"All"},
	}

	result, err := c.sqsClient.ReceiveMessage(ctx, input)
	if err != nil {
		return nil, fmt.Errorf("failed to receive messages from SQS: %w", err)
	}

	messages := make([]*Message, 0, len(result.Messages))
	for _, sqsMsg := range result.Messages {
		msg, err := c.parseMessage(sqsMsg)
		if err != nil {
			// Skipped messages are not deleted here; presumably SQS makes them
			// visible again after the visibility timeout (or moves them to a
			// DLQ if one is configured) — confirm the queue's redrive policy.
			log.Printf("Error parsing SQS message %s: %v", aws.ToString(sqsMsg.MessageId), err)
			continue
		}
		messages = append(messages, msg)
	}

	return messages, nil
}

// parseMessage converts an SQS message to the internal Message structure.
//
// It returns an error when the message ID, body or receipt handle is nil,
// when the body is not valid JSON, or when raw_event_id is missing or empty.
func (c *Client) parseMessage(sqsMsg types.Message) (*Message, error) {
	if sqsMsg.MessageId == nil {
		return nil, errors.New("message ID is nil")
	}
	if sqsMsg.Body == nil {
		return nil, errors.New("message body is nil")
	}
	if sqsMsg.ReceiptHandle == nil {
		return nil, errors.New("receipt handle is nil")
	}

	var rawEventMsg RawEventMessage
	if err := json.Unmarshal([]byte(*sqsMsg.Body), &rawEventMsg); err != nil {
		return nil, fmt.Errorf("failed to parse message body as JSON: %w", err)
	}
	if rawEventMsg.RawEventID == "" {
		return nil, errors.New("raw_event_id is missing from message body")
	}

	// Prefer "correlation_id"; fall back to the alternative "x-correlation-id".
	var correlationID *string
	if sqsMsg.MessageAttributes != nil {
		if attr, ok := sqsMsg.MessageAttributes["correlation_id"]; ok && attr.StringValue != nil {
			correlationID = attr.StringValue
		}
		if correlationID == nil {
			if attr, ok := sqsMsg.MessageAttributes["x-correlation-id"]; ok && attr.StringValue != nil {
				correlationID = attr.StringValue
			}
		}
	}

	return &Message{
		ID:            *sqsMsg.MessageId,
		Body:          *sqsMsg.Body,
		ReceiptHandle: *sqsMsg.ReceiptHandle,
		RawEventID:    rawEventMsg.RawEventID,
		CorrelationID: correlationID,
	}, nil
}

// DeleteMessage removes the message identified by receiptHandle from the
// queue; call it after the message has been processed successfully.
func (c *Client) DeleteMessage(ctx context.Context, receiptHandle string) error {
	input := &sqs.DeleteMessageInput{
		QueueUrl:      aws.String(c.queueURL),
		ReceiptHandle: aws.String(receiptHandle),
	}

	if _, err := c.sqsClient.DeleteMessage(ctx, input); err != nil {
		return fmt.Errorf("failed to delete message from SQS: %w", err)
	}
	return nil
}

// ChangeMessageVisibility sets the visibility timeout of the message
// identified by receiptHandle to visibilityTimeout. It is useful for
// extending the timeout when processing takes longer than expected.
func (c *Client) ChangeMessageVisibility(ctx context.Context, receiptHandle string, visibilityTimeout int32) error {
	input := &sqs.ChangeMessageVisibilityInput{
		QueueUrl:          aws.String(c.queueURL),
		ReceiptHandle:     aws.String(receiptHandle),
		VisibilityTimeout: visibilityTimeout,
	}

	if _, err := c.sqsClient.ChangeMessageVisibility(ctx, input); err != nil {
		return fmt.Errorf("failed to change message visibility: %w", err)
	}
	return nil
}

// PollMessages calls ReceiveMessages on every tick of a one-second ticker and
// forwards each received message to messagesChan until ctx is cancelled.
//
// Receive errors are logged and, when errorsChan is non-nil, sent on
// errorsChan; polling then continues. Every channel send also watches ctx, so
// a consumer that stops reading cannot keep this method blocked after
// cancellation. PollMessages blocks until ctx is done and never closes either
// channel.
func (c *Client) PollMessages(ctx context.Context, messagesChan chan<- *Message, errorsChan chan<- error) {
	log.Printf("Starting SQS polling for queue: %s", c.queueURL)

	ticker := time.NewTicker(1 * time.Second)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			log.Println("SQS polling stopped due to context cancellation")
			return
		case <-ticker.C:
		}

		messages, err := c.ReceiveMessages(ctx)
		if err != nil {
			// A receive interrupted by cancellation is shutdown, not a failure;
			// don't report it (reporting could block on an unread errorsChan).
			if ctx.Err() != nil {
				log.Println("SQS polling stopped due to context cancellation")
				return
			}
			log.Printf("Error receiving messages: %v", err)
			// A nil errorsChan would block forever; treat it as "not interested".
			if errorsChan != nil {
				select {
				case errorsChan <- err:
				case <-ctx.Done():
					return
				}
			}
			continue
		}

		if len(messages) > 0 {
			log.Printf("Received %d messages from SQS", len(messages))
		}

		for _, msg := range messages {
			select {
			case messagesChan <- msg:
				log.Printf("Sent message %s to processing channel", msg.ID)
			case <-ctx.Done():
				return
			}
		}
	}
}

// GetQueueAttributes returns the queue's approximate message counts (visible,
// not visible, and delayed) for monitoring, as the attribute map returned by
// SQS keyed by attribute name.
func (c *Client) GetQueueAttributes(ctx context.Context) (map[string]string, error) {
	input := &sqs.GetQueueAttributesInput{
		QueueUrl: aws.String(c.queueURL),
		AttributeNames: []types.QueueAttributeName{
			types.QueueAttributeNameApproximateNumberOfMessages,
			types.QueueAttributeNameApproximateNumberOfMessagesNotVisible,
			types.QueueAttributeNameApproximateNumberOfMessagesDelayed,
		},
	}

	result, err := c.sqsClient.GetQueueAttributes(ctx, input)
	if err != nil {
		return nil, fmt.Errorf("failed to get queue attributes: %w", err)
	}
	return result.Attributes, nil
}