package wsstomp import ( "context" "fmt" "net/http" "time" "github.com/coder/websocket" ) type WebsocketSTOMP struct { connection *websocket.Conn readerBuffer []byte writeBuffer []byte } const ( NullByte = 0x00 LineFeedByte = 0x0a writeTimeout = 10 * time.Second ) // Read reads messages from the websocket connection until the provided array is full. // Any surplus data is preserved for the next Read call. func (w *WebsocketSTOMP) Read(p []byte) (int, error) { if len(w.readerBuffer) == 0 { _, msg, err := w.connection.Read(context.Background()) if err != nil { return 0, err } w.readerBuffer = msg } n := copy(p, w.readerBuffer) w.readerBuffer = w.readerBuffer[n:] return n, nil } // Write sends data to the websocket. // The data is buffered until a full STOMP frame is written, then sent in a WS message. func (w *WebsocketSTOMP) Write(p []byte) (int, error) { w.writeBuffer = append(w.writeBuffer, p...) // Send if we reach a null byte or the message is a single heartbeat (linefeed). if p[len(p)-1] == NullByte || (len(w.writeBuffer) == 1 && len(p) == 1 && p[0] == LineFeedByte) { err := w.sendMessage(context.Background()) if err != nil { return 0, err } } return len(p), nil } // sendMessage sends the accumulated writeBuffer data via the websocket. func (w *WebsocketSTOMP) sendMessage(ctx context.Context) error { ctx, cancel := context.WithTimeout(ctx, writeTimeout) defer cancel() err := w.connection.Write(ctx, websocket.MessageText, w.writeBuffer) if err != nil { // Preserve the buffer in case of a write failure. return fmt.Errorf("failed to write message: %w", err) } w.writeBuffer = nil return nil } // Close closes the websocket connection with a normal closure status. func (w *WebsocketSTOMP) Close() error { return w.connection.Close(websocket.StatusNormalClosure, "terminating connection") } // Connect establishes a websocket connection with the provided URL. // The context is used only for the connection handshake. func Connect(ctx context.Context, url string, options *websocket.DialOptions) (*WebsocketSTOMP, error) { if options == nil { options = &websocket.DialOptions{} } if options.HTTPClient == nil { options.HTTPClient = &http.Client{ CheckRedirect: func(req *http.Request, via []*http.Request) error { switch req.URL.Scheme { case "ws": req.URL.Scheme = "http" case "wss": req.URL.Scheme = "https" default: return fmt.Errorf("unexpected url scheme: %q", req.URL.Scheme) } return nil }, Timeout: 30 * time.Second, } } con, _, err := websocket.Dial(ctx, url, options) if err != nil { return nil, fmt.Errorf("failed to dial websocket: %w", err) } return &WebsocketSTOMP{ connection: con, }, nil }