mirror of
https://github.com/anotherhadi/spilltea.git
synced 2026-05-20 01:32:33 +02:00
QOL & Security improvement
Signed-off-by: Hadi <112569860+anotherhadi@users.noreply.github.com>
This commit is contained in:
+9
-2
@@ -2,12 +2,14 @@ package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"sync"
|
||||
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
type DB struct {
|
||||
conn *sql.DB
|
||||
conn *sql.DB
|
||||
dedupMu sync.Mutex
|
||||
}
|
||||
|
||||
func Open(path string) (*DB, error) {
|
||||
@@ -33,7 +35,8 @@ func (d *DB) migrate() error {
|
||||
path TEXT NOT NULL,
|
||||
status_code INTEGER NOT NULL,
|
||||
request_raw TEXT NOT NULL,
|
||||
response_raw TEXT NOT NULL
|
||||
response_raw TEXT NOT NULL,
|
||||
body_hash TEXT NOT NULL DEFAULT ''
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS replay_entries (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
@@ -65,6 +68,10 @@ CREATE TABLE IF NOT EXISTS replay_entries (
|
||||
UNIQUE(plugin_name, dedup_key)
|
||||
);
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = d.conn.Exec(`CREATE INDEX IF NOT EXISTS idx_entries_dedup ON entries(method, host, path, body_hash)`)
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
+40
-40
@@ -1,6 +1,7 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
@@ -18,40 +19,45 @@ type Entry struct {
|
||||
ResponseRaw string
|
||||
}
|
||||
|
||||
// HasDuplicate returns true if an entry with the same method, host, path and
|
||||
// request body already exists. Used to implement skip_duplicates filtering.
|
||||
func (d *DB) HasDuplicate(method, host, path, body string) (bool, error) {
|
||||
rows, err := d.conn.Query(
|
||||
`SELECT request_raw FROM entries WHERE method = ? AND host = ? AND path = ?`,
|
||||
method, host, path,
|
||||
)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var raw string
|
||||
if err := rows.Scan(&raw); err != nil {
|
||||
return false, err
|
||||
}
|
||||
parts := strings.SplitN(raw, "\n\n", 2)
|
||||
entryBody := ""
|
||||
if len(parts) == 2 {
|
||||
entryBody = parts[1]
|
||||
}
|
||||
if entryBody == body {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, rows.Err()
|
||||
func bodyHash(body string) string {
|
||||
sum := sha256.Sum256([]byte(body))
|
||||
return fmt.Sprintf("%x", sum)
|
||||
}
|
||||
|
||||
func (d *DB) InsertEntry(e Entry) (Entry, error) {
|
||||
// HasDuplicate returns true if an entry with the same method, host, path and
|
||||
// request body hash already exists.
|
||||
func (d *DB) HasDuplicate(method, host, path, body string) (bool, error) {
|
||||
hash := bodyHash(body)
|
||||
var exists int
|
||||
err := d.conn.QueryRow(
|
||||
`SELECT 1 FROM entries WHERE method = ? AND host = ? AND path = ? AND body_hash = ? LIMIT 1`,
|
||||
method, host, path, hash,
|
||||
).Scan(&exists)
|
||||
if err == sql.ErrNoRows {
|
||||
return false, nil
|
||||
}
|
||||
return err == nil, err
|
||||
}
|
||||
|
||||
// InsertIfNotDuplicate atomically checks for a duplicate and inserts if none
|
||||
// exists. Returns (entry, isDuplicate, error).
|
||||
func (d *DB) InsertIfNotDuplicate(e Entry, body string) (Entry, bool, error) {
|
||||
d.dedupMu.Lock()
|
||||
defer d.dedupMu.Unlock()
|
||||
dup, err := d.HasDuplicate(e.Method, e.Host, e.Path, body)
|
||||
if err != nil || dup {
|
||||
return e, dup, err
|
||||
}
|
||||
e, err = d.InsertEntry(e, body)
|
||||
return e, false, err
|
||||
}
|
||||
|
||||
func (d *DB) InsertEntry(e Entry, body string) (Entry, error) {
|
||||
res, err := d.conn.Exec(
|
||||
`INSERT INTO entries (timestamp, method, host, path, status_code, request_raw, response_raw)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)`,
|
||||
`INSERT INTO entries (timestamp, method, host, path, status_code, request_raw, response_raw, body_hash)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
e.Timestamp.UTC().Format(time.RFC3339),
|
||||
e.Method, e.Host, e.Path, e.StatusCode, e.RequestRaw, e.ResponseRaw,
|
||||
e.Method, e.Host, e.Path, e.StatusCode, e.RequestRaw, e.ResponseRaw, bodyHash(body),
|
||||
)
|
||||
if err != nil {
|
||||
return e, err
|
||||
@@ -102,16 +108,10 @@ func (d *DB) SearchEntries(term string) ([]Entry, error) {
|
||||
return scanEntries(rows)
|
||||
}
|
||||
|
||||
// QueryEntries executes a user-supplied query against the entries table.
|
||||
// If the query does not start with SELECT, it is treated as a WHERE expression
|
||||
// and wrapped automatically (e.g. "status_code = 404" becomes a full SELECT).
|
||||
func (d *DB) QueryEntries(rawSQL string) ([]Entry, error) {
|
||||
q := strings.TrimSpace(rawSQL)
|
||||
if !strings.HasPrefix(strings.ToUpper(q), "SELECT") {
|
||||
q = "SELECT id, timestamp, method, host, path, status_code, request_raw, response_raw FROM entries WHERE " + q
|
||||
} else if strings.ContainsAny(strings.ToUpper(q), "INSERTDELETEUPDATEDROP") {
|
||||
return nil, fmt.Errorf("only SELECT queries are allowed")
|
||||
}
|
||||
// QueryEntries runs a WHERE expression supplied by the user against the entries
|
||||
// table (e.g. "status_code = 404" or "host LIKE '%example.com%'").
|
||||
func (d *DB) QueryEntries(where string) ([]Entry, error) {
|
||||
q := "SELECT id, timestamp, method, host, path, status_code, request_raw, response_raw FROM entries WHERE " + strings.TrimSpace(where)
|
||||
rows, err := d.conn.Query(q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
Reference in New Issue
Block a user