// Source: realstomp-go/wsstomp.go (106 lines, 2.7 KiB, Go)

package wsstomp
import (
"context"
"fmt"
"net/http"
"time"
"github.com/coder/websocket"
)
// WebsocketSTOMP carries STOMP frames over a websocket connection.
// Incoming websocket messages are buffered so Read can hand them out in
// caller-sized pieces, and outgoing bytes are buffered by Write until a
// complete frame can be sent as one websocket message.
type WebsocketSTOMP struct {
	// connection is the underlying websocket.
	connection *websocket.Conn

	// readerBuffer holds the not-yet-consumed remainder of the last received
	// message; writeBuffer accumulates outgoing bytes until they are sent.
	readerBuffer, writeBuffer []byte
}
// NullByte is the byte that terminates a STOMP frame; Write flushes its
// buffer when a write ends with it.
const NullByte = 0x00

// LineFeedByte is a line feed; a lone one written on an empty buffer is
// treated as a heart-beat and sent immediately.
const LineFeedByte = 0x0A

// writeTimeout bounds how long a single websocket message write may take.
const writeTimeout = time.Second * 10
// Read implements io.Reader on top of the websocket connection.
//
// Each call returns bytes from at most one websocket message. If part of a
// previously received message is still buffered, that is returned first;
// otherwise Read blocks until the next message arrives. When p is smaller
// than the pending data, the surplus is kept for the next Read call.
//
// Zero-length websocket messages are skipped, so for a non-empty p Read
// never returns (0, nil). A zero-length p returns (0, nil) immediately
// without touching the connection.
func (w *WebsocketSTOMP) Read(p []byte) (int, error) {
	// Don't block on the network (and consume a message) when there is
	// nowhere to put the data.
	if len(p) == 0 {
		return 0, nil
	}
	for len(w.readerBuffer) == 0 {
		// The websocket message type is discarded; only the payload is used.
		_, msg, err := w.connection.Read(context.Background())
		if err != nil {
			return 0, err
		}
		w.readerBuffer = msg
	}
	n := copy(p, w.readerBuffer)
	w.readerBuffer = w.readerBuffer[n:]
	if len(w.readerBuffer) == 0 {
		// Drop the reference to the fully consumed message so it can be
		// garbage collected before the next message arrives.
		w.readerBuffer = nil
	}
	return n, nil
}
// Write implements io.Writer by accumulating p in an internal buffer and
// sending the buffered bytes as one websocket message once a full STOMP
// frame has been written.
//
// The buffer is sent when p ends with NullByte (end of a frame), or when a
// single LineFeedByte is written while nothing else is buffered (a
// heart-beat). Writing an empty slice is a no-op.
//
// If sending fails, Write returns 0 and the error. p has already been
// appended to the internal buffer at that point, and the buffer is kept
// for the next send attempt.
// NOTE(review): re-writing the same p after an error will duplicate it in
// the buffer.
func (w *WebsocketSTOMP) Write(p []byte) (int, error) {
	if len(p) == 0 {
		// Nothing to buffer or send; also avoids indexing p[len(p)-1] on an
		// empty slice, which would panic.
		return 0, nil
	}
	w.writeBuffer = append(w.writeBuffer, p...)

	endsFrame := p[len(p)-1] == NullByte
	// A one-byte buffer after appending a one-byte p means the buffer was
	// empty before this call, i.e. the line feed stands alone.
	isHeartbeat := len(w.writeBuffer) == 1 && len(p) == 1 && p[0] == LineFeedByte

	if endsFrame || isHeartbeat {
		if err := w.sendMessage(context.Background()); err != nil {
			return 0, err
		}
	}
	return len(p), nil
}
// sendMessage transmits everything accumulated in writeBuffer as a single
// websocket text message. The write is bounded by writeTimeout, derived from
// ctx. On success the buffer is cleared. On failure it is left untouched and
// the wrapped error is returned.
func (w *WebsocketSTOMP) sendMessage(ctx context.Context) error {
	writeCtx, cancelWrite := context.WithTimeout(ctx, writeTimeout)
	defer cancelWrite()

	if err := w.connection.Write(writeCtx, websocket.MessageText, w.writeBuffer); err != nil {
		// writeBuffer is deliberately not reset so its contents are not lost.
		return fmt.Errorf("failed to write message: %w", err)
	}

	w.writeBuffer = nil
	return nil
}
// Close shuts down the underlying websocket connection, reporting a normal
// closure status and a short reason to the peer. It returns whatever error
// the websocket library reports.
func (w *WebsocketSTOMP) Close() error {
	const reason = "terminating connection"
	status := websocket.StatusNormalClosure
	return w.connection.Close(status, reason)
}
// Connect dials url and returns a WebsocketSTOMP wrapping the resulting
// websocket connection.
//
// ctx is used only for the connection handshake. options may be nil. If
// options carries no HTTPClient, a client with a 30 second timeout and the
// followWebsocketRedirect redirect policy is used. The caller's options value
// is never modified; a shallow copy is passed to the dialer.
func Connect(ctx context.Context, url string, options *websocket.DialOptions) (*WebsocketSTOMP, error) {
	// Work on a copy so filling in defaults does not mutate the caller's
	// (possibly shared/reused) DialOptions.
	var opts websocket.DialOptions
	if options != nil {
		opts = *options
	}
	if opts.HTTPClient == nil {
		opts.HTTPClient = &http.Client{
			CheckRedirect: followWebsocketRedirect,
			Timeout:       30 * time.Second,
		}
	}

	con, _, err := websocket.Dial(ctx, url, &opts)
	if err != nil {
		return nil, fmt.Errorf("failed to dial websocket: %w", err)
	}
	return &WebsocketSTOMP{connection: con}, nil
}

// followWebsocketRedirect is an http.Client CheckRedirect policy for the
// websocket handshake. It rewrites ws/wss redirect targets to http/https so
// net/http can follow them, lets http/https targets through unchanged,
// rejects any other scheme, and stops after 10 redirects.
func followWebsocketRedirect(req *http.Request, via []*http.Request) error {
	// Supplying a CheckRedirect replaces net/http's default policy, including
	// its 10-redirect cap, so re-impose the cap to avoid redirect loops.
	const maxRedirects = 10
	if len(via) >= maxRedirects {
		return fmt.Errorf("stopped after %d redirects", maxRedirects)
	}
	switch req.URL.Scheme {
	case "ws":
		req.URL.Scheme = "http"
	case "wss":
		req.URL.Scheme = "https"
	case "http", "https":
		// A plain HTTP(S) Location is a valid handshake target as-is.
	default:
		return fmt.Errorf("unexpected url scheme: %q", req.URL.Scheme)
	}
	return nil
}