106 lines
2.7 KiB
Go
106 lines
2.7 KiB
Go
package wsstomp
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net/http"
|
|
"time"
|
|
|
|
"github.com/coder/websocket"
|
|
)
|
|
|
|
type WebsocketSTOMP struct {
|
|
connection *websocket.Conn
|
|
readerBuffer []byte
|
|
writeBuffer []byte
|
|
}
|
|
|
|
const (
	// writeTimeout caps how long a single outgoing websocket message write
	// may take.
	writeTimeout = 10 * time.Second

	// NullByte (0x00) terminates a STOMP frame.
	NullByte = 0x00
	// LineFeedByte (0x0a, '\n') written on its own is a STOMP heart-beat.
	LineFeedByte = 0x0a
)
|
|
|
|
// Read reads messages from the websocket connection until the provided array is full.
|
|
// Any surplus data is preserved for the next Read call.
|
|
func (w *WebsocketSTOMP) Read(p []byte) (int, error) {
|
|
if len(w.readerBuffer) == 0 {
|
|
_, msg, err := w.connection.Read(context.Background())
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
w.readerBuffer = msg
|
|
}
|
|
|
|
n := copy(p, w.readerBuffer)
|
|
w.readerBuffer = w.readerBuffer[n:]
|
|
return n, nil
|
|
}
|
|
|
|
// Write sends data to the websocket.
|
|
// The data is buffered until a full STOMP frame is written, then sent in a WS message.
|
|
func (w *WebsocketSTOMP) Write(p []byte) (int, error) {
|
|
w.writeBuffer = append(w.writeBuffer, p...)
|
|
|
|
// Send if we reach a null byte or the message is a single heartbeat (linefeed).
|
|
if p[len(p)-1] == NullByte || (len(w.writeBuffer) == 1 && len(p) == 1 && p[0] == LineFeedByte) {
|
|
err := w.sendMessage(context.Background())
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
}
|
|
return len(p), nil
|
|
}
|
|
|
|
// sendMessage sends the accumulated writeBuffer data via the websocket.
|
|
func (w *WebsocketSTOMP) sendMessage(ctx context.Context) error {
|
|
ctx, cancel := context.WithTimeout(ctx, writeTimeout)
|
|
defer cancel()
|
|
|
|
err := w.connection.Write(ctx, websocket.MessageText, w.writeBuffer)
|
|
if err != nil {
|
|
// Preserve the buffer in case of a write failure.
|
|
return fmt.Errorf("failed to write message: %w", err)
|
|
}
|
|
w.writeBuffer = nil
|
|
return nil
|
|
}
|
|
|
|
// Close closes the websocket connection with a normal closure status.
|
|
func (w *WebsocketSTOMP) Close() error {
|
|
return w.connection.Close(websocket.StatusNormalClosure, "terminating connection")
|
|
}
|
|
|
|
// Connect establishes a websocket connection with the provided URL.
|
|
// The context is used only for the connection handshake.
|
|
func Connect(ctx context.Context, url string, options *websocket.DialOptions) (*WebsocketSTOMP, error) {
|
|
if options == nil {
|
|
options = &websocket.DialOptions{}
|
|
}
|
|
if options.HTTPClient == nil {
|
|
options.HTTPClient = &http.Client{
|
|
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
|
switch req.URL.Scheme {
|
|
case "ws":
|
|
req.URL.Scheme = "http"
|
|
case "wss":
|
|
req.URL.Scheme = "https"
|
|
default:
|
|
return fmt.Errorf("unexpected url scheme: %q", req.URL.Scheme)
|
|
}
|
|
return nil
|
|
},
|
|
Timeout: 30 * time.Second,
|
|
}
|
|
}
|
|
|
|
con, _, err := websocket.Dial(ctx, url, options)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to dial websocket: %w", err)
|
|
}
|
|
|
|
return &WebsocketSTOMP{
|
|
connection: con,
|
|
}, nil
|
|
}
|