QOL & Security improvement

Signed-off-by: Hadi <112569860+anotherhadi@users.noreply.github.com>
This commit is contained in:
Hadi
2026-05-19 10:09:42 +02:00
parent 03260e0947
commit a147e8b972
12 changed files with 160 additions and 154 deletions
+9 -2
View File
@@ -2,12 +2,14 @@ package db
import (
"database/sql"
"sync"
_ "modernc.org/sqlite"
)
type DB struct {
conn *sql.DB
conn *sql.DB
dedupMu sync.Mutex
}
func Open(path string) (*DB, error) {
@@ -33,7 +35,8 @@ func (d *DB) migrate() error {
path TEXT NOT NULL,
status_code INTEGER NOT NULL,
request_raw TEXT NOT NULL,
response_raw TEXT NOT NULL
response_raw TEXT NOT NULL,
body_hash TEXT NOT NULL DEFAULT ''
);
CREATE TABLE IF NOT EXISTS replay_entries (
id INTEGER PRIMARY KEY AUTOINCREMENT,
@@ -65,6 +68,10 @@ CREATE TABLE IF NOT EXISTS replay_entries (
UNIQUE(plugin_name, dedup_key)
);
`)
if err != nil {
return err
}
_, err = d.conn.Exec(`CREATE INDEX IF NOT EXISTS idx_entries_dedup ON entries(method, host, path, body_hash)`)
return err
}
+40 -40
View File
@@ -1,6 +1,7 @@
package db
import (
"crypto/sha256"
"database/sql"
"fmt"
"strings"
@@ -18,40 +19,45 @@ type Entry struct {
ResponseRaw string
}
// HasDuplicate returns true if an entry with the same method, host, path and
// request body already exists. Used to implement skip_duplicates filtering.
func (d *DB) HasDuplicate(method, host, path, body string) (bool, error) {
rows, err := d.conn.Query(
`SELECT request_raw FROM entries WHERE method = ? AND host = ? AND path = ?`,
method, host, path,
)
if err != nil {
return false, err
}
defer rows.Close()
for rows.Next() {
var raw string
if err := rows.Scan(&raw); err != nil {
return false, err
}
parts := strings.SplitN(raw, "\n\n", 2)
entryBody := ""
if len(parts) == 2 {
entryBody = parts[1]
}
if entryBody == body {
return true, nil
}
}
return false, rows.Err()
func bodyHash(body string) string {
sum := sha256.Sum256([]byte(body))
return fmt.Sprintf("%x", sum)
}
func (d *DB) InsertEntry(e Entry) (Entry, error) {
// HasDuplicate returns true if an entry with the same method, host, path and
// request body hash already exists.
func (d *DB) HasDuplicate(method, host, path, body string) (bool, error) {
hash := bodyHash(body)
var exists int
err := d.conn.QueryRow(
`SELECT 1 FROM entries WHERE method = ? AND host = ? AND path = ? AND body_hash = ? LIMIT 1`,
method, host, path, hash,
).Scan(&exists)
if err == sql.ErrNoRows {
return false, nil
}
return err == nil, err
}
// InsertIfNotDuplicate atomically checks for a duplicate and inserts if none
// exists. Returns (entry, isDuplicate, error).
func (d *DB) InsertIfNotDuplicate(e Entry, body string) (Entry, bool, error) {
d.dedupMu.Lock()
defer d.dedupMu.Unlock()
dup, err := d.HasDuplicate(e.Method, e.Host, e.Path, body)
if err != nil || dup {
return e, dup, err
}
e, err = d.InsertEntry(e, body)
return e, false, err
}
func (d *DB) InsertEntry(e Entry, body string) (Entry, error) {
res, err := d.conn.Exec(
`INSERT INTO entries (timestamp, method, host, path, status_code, request_raw, response_raw)
VALUES (?, ?, ?, ?, ?, ?, ?)`,
`INSERT INTO entries (timestamp, method, host, path, status_code, request_raw, response_raw, body_hash)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
e.Timestamp.UTC().Format(time.RFC3339),
e.Method, e.Host, e.Path, e.StatusCode, e.RequestRaw, e.ResponseRaw,
e.Method, e.Host, e.Path, e.StatusCode, e.RequestRaw, e.ResponseRaw, bodyHash(body),
)
if err != nil {
return e, err
@@ -102,16 +108,10 @@ func (d *DB) SearchEntries(term string) ([]Entry, error) {
return scanEntries(rows)
}
// QueryEntries executes a user-supplied query against the entries table.
// If the query does not start with SELECT, it is treated as a WHERE expression
// and wrapped automatically (e.g. "status_code = 404" becomes a full SELECT).
func (d *DB) QueryEntries(rawSQL string) ([]Entry, error) {
q := strings.TrimSpace(rawSQL)
if !strings.HasPrefix(strings.ToUpper(q), "SELECT") {
q = "SELECT id, timestamp, method, host, path, status_code, request_raw, response_raw FROM entries WHERE " + q
} else if strings.ContainsAny(strings.ToUpper(q), "INSERTDELETEUPDATEDROP") {
return nil, fmt.Errorf("only SELECT queries are allowed")
}
// QueryEntries runs a WHERE expression supplied by the user against the entries
// table (e.g. "status_code = 404" or "host LIKE '%example.com%'").
func (d *DB) QueryEntries(where string) ([]Entry, error) {
q := "SELECT id, timestamp, method, host, path, status_code, request_raw, response_raw FROM entries WHERE " + strings.TrimSpace(where)
rows, err := d.conn.Query(q)
if err != nil {
return nil, err