HEX
Server: Apache/2
System: Linux nexus-01 4.18.0-553.120.1.el8_10.x86_64 #1 SMP Mon Apr 20 18:04:27 EDT 2026 x86_64
User: aglcoke (1118)
PHP: 8.2.31
Disabled: mail,exec,system,passthru,shell_exec,proc_close,proc_open,dl,popen,show_source,posix_kill,posix_mkfifo,posix_getpwuid,posix_setpgid,posix_setsid,posix_setuid,posix_setgid,posix_seteuid,posix_setegid,posix_uname
Upload Files
File: //proc/1/task/1/root/usr/share/rspamd/lualib/llm_context.lua
--[[
Context management for LLM-based spam detection

Provides:
  - fetch(task, redis_params, opts, callback, debug_module): load context JSON from Redis and format prompt snippet
  - update_after_classification(task, redis_params, opts, result, sel_part, debug_module): update context after LLM result

Opts (all optional, safe defaults applied):
  enabled: boolean
  level: 'user' | 'domain' | 'esld' (scope for context key)
  key_prefix: string (prefix before scope)
  key_suffix: string (suffix after identity)
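  max_messages: number (sliding window size)
  min_messages: number (minimum number of stored messages before the context is injected into the prompt)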
  message_ttl: seconds
  ttl: seconds (Redis key TTL)
  top_senders: number (how many to keep in top_senders)
  summary_max_chars: number (truncate stored text)
  flagged_phrases: array of strings (case-insensitive match)
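  last_labels_count: number (how many of the most recent classification labels to keep)
  reply_trim_mode: passed through unchanged to llm_common.build_llm_input when summarising a message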

debug_module: optional string, module name for debug logging (default: 'llm_context')
]]

local M = {}

local lua_redis = require "lua_redis"
local lua_util = require "lua_util"
local rspamd_logger = require "rspamd_logger"
local ucl = require "ucl"
local rspamd_util = require "rspamd_util"
local llm_common = require "llm_common"

local EMPTY = {}

-- Default option values; M.fetch and M.update_after_classification merge the
-- caller's opts over this table via lua_util.override_defaults.
local DEFAULTS = {
  -- Both public entry points return early unless this is true
  enabled = false,
  -- Scope passed to get_our_identity: 'user', 'domain' or 'esld'
  level = 'user',
  -- Redis key is built as '<key_prefix>:<identity>:<key_suffix>'
  key_prefix = 'user',
  key_suffix = 'mail_context',
  -- Upper bound on the number of entries kept in ctx.recent_messages
  max_messages = 40,
  min_messages = 5, -- fetch returns no prompt snippet while fewer messages than this are stored
  -- Messages whose ts is older than now() - message_ttl are dropped on update
  message_ttl = 14 * 24 * 3600,
  -- Expiry (seconds) given to SETEX when the context is written back
  ttl = 30 * 24 * 3600,
  -- How many senders recompute_top_senders keeps
  top_senders = 5,
  -- Limit passed to truncate_text for the stored message text
  summary_max_chars = 512,
  -- Phrases searched for (case-insensitively) in the normalised words of the text part
  flagged_phrases = {
    'reset your password',
    'click here to verify',
    'confirm your account',
    'urgent invoice',
    'wire transfer',
  },
  -- Upper bound on the number of entries kept in ctx.last_spam_labels
  last_labels_count = 10,
}

-- Coerce a configured duration to a number.
-- Numbers pass through untouched; anything else goes through tonumber(),
-- and values that cannot be converted become 0.
local function to_seconds(v)
  if type(v) ~= 'number' then
    local parsed = tonumber(v)
    if parsed then
      return parsed
    end
    return 0
  end
  return v
end

-- Return everything after the last '@' of an address, or nil when the
-- address is missing or has no non-empty part after an '@'.
local function get_domain_from_addr(addr)
  if addr then
    -- The leading '.*' is greedy, so the capture starts after the final '@'
    local domain = string.match(addr, '.*@(.+)')
    return domain
  end
  return nil
end

-- Resolve "our" side of the conversation for the requested scope.
-- A message counts as outgoing when the task has an authenticated user or its
-- IP reports is_local(); otherwise it is incoming and the principal recipient
-- is used. Returns the identity for scope 'user', 'domain' or 'esld', or nil
-- when nothing could be determined or the scope is not one of those three.
local function get_our_identity(task, scope)
  local auth_user = task:get_user()
  local client_ip = task:get_ip()
  local outgoing = auth_user or (client_ip and client_ip:is_local())

  -- Field of the first SMTP From entry, tolerating a missing/empty From
  local function envelope_from(field)
    return ((task:get_from('smtp') or EMPTY)[1] or EMPTY)[field]
  end

  -- Domain of our side: sender domain when outgoing, recipient domain otherwise
  local function our_domain()
    if not outgoing then
      local rcpt = task:get_principal_recipient()
      return get_domain_from_addr(rcpt)
    end
    local dom
    if auth_user then
      dom = get_domain_from_addr(auth_user)
    end
    if not dom then
      -- Authenticated user is absent or has no '@domain' part
      dom = envelope_from('domain')
    end
    return dom
  end

  local identity
  if scope == 'user' then
    if outgoing then
      identity = auth_user or task:get_reply_sender()
      if not identity then
        identity = envelope_from('addr')
      end
    else
      identity = task:get_principal_recipient()
    end
  elseif scope == 'domain' then
    identity = our_domain()
  elseif scope == 'esld' then
    local dom = our_domain()
    if dom then
      identity = rspamd_util.get_tld(dom)
    end
  end

  return identity
end

-- Work out the identity for this task and the Redis key derived from it.
-- Returns nil when no (or an empty) identity is available, otherwise a table
-- { scope, identity, key } where key is '<key_prefix>:<identity>:<key_suffix>'.
local function compute_identity(task, opts, debug_module)
  local log_module = debug_module or 'llm_context'
  local scope = opts.level or DEFAULTS.level
  local identity = get_our_identity(task, scope)

  if not (identity and identity ~= '') then
    return nil
  end

  -- Direction is recomputed here purely for the debug message
  local auth_user = task:get_user()
  local client_ip = task:get_ip()
  local direction
  if auth_user or (client_ip and client_ip:is_local()) then
    direction = 'outgoing'
  else
    direction = 'incoming'
  end
  lua_util.debugm(log_module, task, 'computed identity for %s (%s): %s',
    scope, direction, tostring(identity))

  local prefix = opts.key_prefix or DEFAULTS.key_prefix
  local suffix = opts.key_suffix or DEFAULTS.key_suffix

  return {
    scope = scope,
    identity = identity,
    key = string.format('%s:%s:%s', prefix, identity, suffix),
  }
end

-- Decode a Redis reply with the UCL parser.
-- Returns the parsed object on success, nil for absent/empty/non-string
-- input, and nil plus the parser error when the text cannot be parsed.
local function parse_json(data)
  if not data then return nil end

  local raw = data
  -- Redis can return userdata nil or empty string
  if type(raw) == 'userdata' then
    raw = tostring(raw)
  end
  if type(raw) ~= 'string' or raw == '' then
    return nil
  end

  local parser = ucl.parser()
  local ok, err = parser:parse_text(raw)
  if ok then
    return parser:get_object()
  end
  return nil, err
end

-- Serialise obj through ucl.to_format using the 'json-compact' format.
local function encode_json(obj)
  local output_format = 'json-compact'
  return ucl.to_format(obj, output_format, true)
end

-- Current time as returned by os.time() (whole seconds).
local function now()
  local ts = os.time()
  return ts
end

-- Shorten txt to at most `limit` bytes.
--
-- Returns '' for a nil/false txt and txt itself when it already fits.
-- For Lua strings the cut never lands inside a multi-byte UTF-8 character:
-- if byte limit+1 is a continuation byte (0x80..0xBF) the cut moves back to
-- the start of that character, so the result can be up to 3 bytes shorter
-- than `limit`. Input that does not look like UTF-8 (no lead byte within
-- 3 bytes) is cut at the raw byte limit, as before.
--
-- Non-string values (anything else exposing # and :sub) and limits below 1
-- keep the plain sub(1, limit) behaviour.
local function truncate_text(txt, limit)
  if not txt then return '' end
  if #txt <= limit then return txt end
  if type(txt) ~= 'string' or limit < 1 then
    return txt:sub(1, limit)
  end

  local function is_continuation(pos)
    local b = txt:byte(pos)
    return b >= 0x80 and b <= 0xBF
  end

  -- First byte that would be dropped; #txt > limit guarantees it exists
  local next_pos = math.floor(limit) + 1
  local back = 0
  -- A UTF-8 character has at most 3 continuation bytes, so step back at most 3
  while back < 3 and next_pos - back > 1 and is_continuation(next_pos - back) do
    back = back + 1
  end
  if is_continuation(next_pos - back) then
    -- No lead byte found: not valid UTF-8, do not try to be clever
    return txt:sub(1, limit)
  end
  return txt:sub(1, next_pos - back - 1)
end

-- True when flag_name is one of the values in the flags sequence.
-- Anything that is not a table is treated as "no flags".
local function has_flag(flags, flag_name)
  local found = false
  if type(flags) == 'table' then
    for _, candidate in ipairs(flags) do
      if candidate == flag_name then
        found = true
        break
      end
    end
  end
  return found
end

-- Pick the most frequent words of a text part.
-- Reads text_part:get_words('full'), using element [2] of each entry as the
-- word and element [4] as its flag list. Words are counted only when they are
-- longer than 2 bytes, carry the 'text' flag and lack the 'stop_word' flag.
-- Returns up to `limit` (default 12) words ordered by descending count, ties
-- broken alphabetically; returns {} when there is no part or no words.
local function extract_keywords(text_part, limit)
  if not text_part then return {} end
  local words = text_part:get_words('full')
  if not words or #words == 0 then return {} end

  local freq = {}
  for _, entry in ipairs(words) do
    local normalized = entry[2] or '' -- normalized
    local flags = entry[4] or {}
    local eligible = not has_flag(flags, 'stop_word')
        and #normalized > 2
        and has_flag(flags, 'text')
    if eligible then
      freq[normalized] = (freq[normalized] or 0) + 1
    end
  end

  local ranked = {}
  for word, cnt in pairs(freq) do
    ranked[#ranked + 1] = { w = word, c = cnt }
  end
  table.sort(ranked, function(a, b)
    if a.c ~= b.c then
      return a.c > b.c
    end
    return a.w < b.w
  end)

  local top_n = math.min(limit or 12, #ranked)
  local res = {}
  for i = 1, top_n do
    res[i] = ranked[i].w
  end
  return res
end

-- Return arr itself when it is a table, otherwise a fresh empty table.
local function safe_array(arr)
  if type(arr) == 'table' then
    return arr
  end
  return {}
end

-- Build the compact record stored in ctx.recent_messages for this task.
-- Uses llm_common.build_llm_input (max_tokens = 256) for from/subject/text;
-- returns nil when that call does not yield a table. The record holds
-- from, subject, ts, keywords (from sel_part) and, when text is non-empty,
-- text truncated to opts.summary_max_chars.
local function build_message_summary(task, sel_part, opts)
  -- The same input is built whether or not a text part was selected
  local input = llm_common.build_llm_input(task, {
    max_tokens = 256,
    reply_trim_mode = opts.reply_trim_mode,
  })
  if type(input) ~= 'table' then
    return nil
  end

  local body = input.text or ''
  local char_limit = opts.summary_max_chars or DEFAULTS.summary_max_chars

  -- Fall back to the SMTP From address when the LLM input carries no sender
  local sender = input.from
  if not sender then
    sender = ((task:get_from('smtp') or EMPTY)[1] or EMPTY)['addr']
  end

  local summary = {
    from = sender,
    subject = input.subject or '',
    ts = now(),
    keywords = extract_keywords(sel_part, 12),
  }
  if body and #body > 0 then
    summary.text = truncate_text(body, char_limit)
  end
  return summary
end

-- Apply the sliding window to the stored messages.
--
-- Keeps entries whose ts is >= min_ts (all entries when min_ts is nil),
-- sorts them newest first and returns at most max_messages of them in a new
-- table; the input table is not modified.
--
-- The list comes from JSON stored in Redis, so it is treated defensively:
--   * entries that are not tables are skipped instead of raising;
--   * ts is read through tonumber(), so numeric strings are accepted and
--     anything else counts as "no timestamp" (sorted as 0);
--   * a negative max_messages is clamped to 0 (the previous removal loop
--     never terminated in that case); a nil/non-numeric max_messages means
--     "no limit".
local function trim_messages(recent_messages, max_messages, min_ts)
  local limit = tonumber(max_messages)
  if limit then
    limit = math.max(0, math.floor(limit))
  end

  local kept = {}
  if type(recent_messages) == 'table' then
    for _, m in ipairs(recent_messages) do
      if type(m) == 'table' then
        local ts = tonumber(m.ts)
        if not min_ts or (ts and ts >= min_ts) then
          kept[#kept + 1] = m
        end
      end
    end
  end

  table.sort(kept, function(a, b)
    return (tonumber(a.ts) or 0) > (tonumber(b.ts) or 0)
  end)

  if limit then
    -- Drop the oldest entries beyond the window
    for i = #kept, limit + 1, -1 do
      kept[i] = nil
    end
  end
  return kept
end

-- Rank senders by message count.
-- sender_counts maps sender -> count (nil is treated as empty). Returns the
-- first limit_n senders ordered by descending count, ties broken by sender
-- in ascending order.
local function recompute_top_senders(sender_counts, limit_n)
  local ranked = {}
  for sender, count in pairs(sender_counts or {}) do
    ranked[#ranked + 1] = { s = sender, c = count }
  end

  table.sort(ranked, function(x, y)
    if x.c ~= y.c then
      return x.c > y.c
    end
    return x.s < y.s
  end)

  local top = {}
  local n = math.min(limit_n, #ranked)
  for i = 1, n do
    top[i] = ranked[i].s
  end
  return top
end

-- Normalise a context object so every field the module relies on is a table.
-- A non-table ctx is replaced by a new table. recent_messages, top_senders,
-- flagged_phrases and last_spam_labels are forced to tables via safe_array.
-- sender_counts gets the same treatment: the context is decoded from stored
-- JSON, and a non-table value there would make the later
-- `ctx.sender_counts[sender] = ...` assignment raise.
-- Returns the (possibly new) ctx table.
local function ensure_defaults(ctx)
  if type(ctx) ~= 'table' then ctx = {} end
  ctx.recent_messages = safe_array(ctx.recent_messages)
  ctx.top_senders = safe_array(ctx.top_senders)
  ctx.flagged_phrases = safe_array(ctx.flagged_phrases)
  ctx.last_spam_labels = safe_array(ctx.last_spam_labels)
  -- sender_counts is a sender -> count map, but the type guard is the same
  if type(ctx.sender_counts) ~= 'table' then
    ctx.sender_counts = {}
  end
  return ctx
end

-- Case-insensitive plain-text (no pattern matching) substring test.
-- Returns false when either argument is nil/false.
local function contains_ci(haystack, needle)
  if haystack and needle then
    local hay = string.lower(haystack)
    local pin = string.lower(needle)
    -- 4th argument `true` disables Lua pattern magic characters
    local pos = string.find(hay, pin, 1, true)
    return pos ~= nil
  end
  return false
end

-- Record configured phrases that occur in the message text.
-- Joins text_part:get_words('norm') with single spaces and appends to
-- ctx.flagged_phrases every configured phrase found in that text
-- (case-insensitively) that is not already listed. Does nothing when there
-- is no text part or it has no words.
local function update_flagged_phrases(ctx, text_part, opts)
  local phrases = opts.flagged_phrases or DEFAULTS.flagged_phrases
  if not text_part then return end
  local words = text_part:get_words('norm')
  if not words or #words == 0 then return end
  local joined = table.concat(words, ' ')

  -- Case-insensitive membership test against what is already stored
  local function already_recorded(phrase)
    for _, existing in ipairs(ctx.flagged_phrases) do
      if string.lower(existing) == string.lower(phrase) then
        return true
      end
    end
    return false
  end

  for _, phrase in ipairs(phrases) do
    if contains_ci(joined, phrase) and not already_recorded(phrase) then
      table.insert(ctx.flagged_phrases, phrase)
    end
  end
end

-- Render the first limit_n messages as '- <from>: <subject>' lines joined by
-- newlines. The sender is taken from m.from, then m.sender, then ''.
-- Returns '' when there is nothing to render.
local function to_bullets_recent(recent_messages, limit_n)
  local shown = math.min(limit_n, #recent_messages)
  local out = {}
  for idx = 1, shown do
    local entry = recent_messages[idx]
    local who = entry.from or entry.sender or ''
    local what = entry.subject or ''
    out[idx] = string.format('- %s: %s', who, what)
  end
  return table.concat(out, '\n')
end

-- Join a sequence with ', '; nil or an empty sequence yields ''.
local function join_list(arr)
  if arr and #arr > 0 then
    return table.concat(arr, ', ')
  end
  return ''
end

-- Turn a context object into the text snippet added to the LLM prompt.
-- Lists up to 5 recent messages, the top senders, any flagged phrases and
-- recent labels, and classifies the current SMTP sender by how many times it
-- appears in ctx.sender_counts: 'new' (absent or no task), 'occasional'
-- (< 3), 'known' (>= 3) or 'frequent' (>= 10).
local function format_context_prompt(ctx, task)
  local recent_block = to_bullets_recent(ctx.recent_messages or {}, 5)
  local senders_line = join_list(ctx.top_senders or {})
  local phrases_line = join_list(ctx.flagged_phrases or {})
  local labels_line = join_list(ctx.last_spam_labels or {})

  -- Check if current sender is known
  local familiarity = 'new'
  if task then
    local from_addr = ((task:get_from('smtp') or EMPTY)[1] or EMPTY)['addr']
    local seen = from_addr and ctx.sender_counts and ctx.sender_counts[from_addr]
    if seen then
      if seen >= 10 then
        familiarity = 'frequent'
      elseif seen >= 3 then
        familiarity = 'known'
      else
        familiarity = 'occasional'
      end
    end
  end

  local out = { 'User recent correspondence summary:' }
  if recent_block == '' then
    out[#out + 1] = '- (no recent messages)'
  else
    out[#out + 1] = recent_block
  end
  out[#out + 1] = string.format('Top senders in mailbox: %s', senders_line)
  -- Optional sections are omitted entirely when empty
  if phrases_line ~= '' then
    out[#out + 1] = string.format('Recently flagged suspicious phrases: %s', phrases_line)
  end
  if labels_line ~= '' then
    out[#out + 1] = string.format('Last detected spam types: %s', labels_line)
  end
  out[#out + 1] = string.format('Current sender: %s', familiarity)

  return table.concat(out, '\n')
end

-- Load the stored context for this task's identity and hand it to callback.
--
-- callback(err, ctx, prompt_snippet) is invoked exactly once per call path:
--   * (nil, nil, nil)            when the feature is disabled;
--   * ('no redis' | 'no identity' | 'request not scheduled', nil, nil)
--                                when the lookup cannot be started;
--   * (err, nil, nil)            when the Redis GET reports an error;
--   * (nil, ctx, nil)            when fewer than min_messages are stored;
--   * (nil, ctx, prompt_snippet) otherwise.
-- debug_module names the module used for debug logging (default 'llm_context').
function M.fetch(task, redis_params, opts, callback, debug_module)
  local log_module = debug_module or 'llm_context'
  opts = lua_util.override_defaults(DEFAULTS, opts or {})

  if not opts.enabled then
    callback(nil, nil, nil)
    return
  end
  if not redis_params then
    callback('no redis', nil, nil)
    return
  end

  local ident = compute_identity(task, opts, log_module)
  if not ident then
    lua_util.debugm(log_module, task, 'no identity computed, skipping context')
    callback('no identity', nil, nil)
    return
  end

  lua_util.debugm(log_module, task, 'fetching context for %s: %s',
    tostring(ident.scope), tostring(ident.identity))

  local function handle_reply(err, data)
    if err then
      rspamd_logger.errx(task, 'llm_context: get failed: %s', tostring(err))
      callback(err, nil, nil)
      return
    end

    local stored = {}
    if data then
      lua_util.debugm(log_module, task, 'got context data from redis, parsing')
      -- Only the first return value matters; a parse failure yields an empty context
      stored = (parse_json(data)) or {}
    else
      lua_util.debugm(log_module, task, 'no context data in redis, using empty')
    end
    local ctx = ensure_defaults(stored)

    -- Check if context has enough messages for warm-up
    local required = opts.min_messages or DEFAULTS.min_messages
    local have = #(ctx.recent_messages or {})
    if have >= required then
      lua_util.debugm(log_module, task, 'context warm-up OK: %s messages, generating snippet',
        tostring(have))
      local snippet = format_context_prompt(ctx, task)
      callback(nil, ctx, snippet)
      return
    end

    lua_util.debugm(log_module, task, 'context has only %s messages (min: %s), not injecting into prompt',
      tostring(have), tostring(required))
    callback(nil, ctx, nil) -- return ctx but no prompt snippet
  end

  local scheduled = lua_redis.redis_make_request(task, redis_params, ident.key, false,
    handle_reply, 'GET', { ident.key })
  if not scheduled then
    callback('request not scheduled', nil, nil)
  end
end

-- Fold the just-classified message into the stored context.
--
-- Flow: GET the current context, prepend a summary of this message, bump the
-- sender counter, update flagged phrases, trim the message window, recompute
-- top senders, record labels from `result`, then write everything back with
-- SETEX using opts.ttl. Returns nothing; failures are only logged.
--
--   result   : optional table; `categories` (array) and `probability` are read
--   sel_part : text part used for keywords and flagged-phrase matching (may be nil)
--
-- NOTE(review): the GET and the SETEX are separate requests with no
-- transaction visible here, so two tasks updating the same key at the same
-- time could presumably overwrite each other's changes - confirm whether
-- that is acceptable.
-- NOTE(review): nothing in this function removes entries from
-- ctx.sender_counts, so the map only shrinks when the key expires - verify
-- that the payload size stays reasonable.
function M.update_after_classification(task, redis_params, opts, result, sel_part, debug_module)
  local N = debug_module or 'llm_context'
  opts = lua_util.override_defaults(DEFAULTS, opts or {})
  if not opts.enabled then return end
  if not redis_params then return end

  local ident = compute_identity(task, opts, N)
  if not ident then return end

  -- Continuation run with the reply of the initial GET
  local function on_get(err, data)
    if err then
      rspamd_logger.errx(task, 'llm_context: get for update failed: %s', tostring(err))
      return
    end
    lua_util.debugm(N, task, 'updating context for %s: %s',
      tostring(ident.scope), tostring(ident.identity))
    -- Missing or unparsable data starts from an empty context
    local ctx = ensure_defaults(select(1, parse_json(data)) or {})

    local msg = build_message_summary(task, sel_part, opts)
    if msg then
      -- Newest message goes to the front of the list
      table.insert(ctx.recent_messages, 1, msg)
      local sender = msg.from or ''
      if sender ~= '' then
        ctx.sender_counts[sender] = (ctx.sender_counts[sender] or 0) + 1
      end
      update_flagged_phrases(ctx, sel_part, opts)
    end

    -- Drop messages older than message_ttl and cap the list at max_messages
    local min_ts = now() - to_seconds(opts.message_ttl)
    ctx.recent_messages = trim_messages(ctx.recent_messages, opts.max_messages, min_ts)
    ctx.top_senders = recompute_top_senders(ctx.sender_counts, opts.top_senders)

    -- Labels for this message: every category plus 'spam'/'ham' from probability
    local labels = {}
    if result then
      if result.categories and type(result.categories) == 'table' then
        for _, c in ipairs(result.categories) do table.insert(labels, tostring(c)) end
      end
      if result.probability then
        if result.probability > 0.5 then
          table.insert(labels, 'spam')
        else
          table.insert(labels, 'ham')
        end
      end
    end
    -- Each label is inserted at the front; the tail is trimmed to last_labels_count
    for _, l in ipairs(labels) do table.insert(ctx.last_spam_labels, 1, l) end
    while #ctx.last_spam_labels > opts.last_labels_count do table.remove(ctx.last_spam_labels) end

    ctx.updated_at = now()

    local payload = encode_json(ctx)
    local ttl = to_seconds(opts.ttl)
    -- expire_at is only used in the log messages below
    local expire_at = now() + ttl

    -- Log what we're storing in context
    lua_util.debugm(N, task,
      'storing context for %s: %s messages, labels=%s, top_senders=%s, flagged=%s, payload_size=%s bytes, expiring at %s',
      tostring(ident.identity or '(none)'),
      tostring(#ctx.recent_messages),
      table.concat(ctx.last_spam_labels or {}, ','),
      table.concat(ctx.top_senders or {}, ','),
      table.concat(ctx.flagged_phrases or {}, ','),
      tostring(#payload),
      os.date('%Y-%m-%d %H:%M:%S', expire_at))

    if msg then
      lua_util.debugm(N, task,
        'added message: from=%s, subject=%s, keywords=%s',
        tostring(msg.from or '(none)'),
        tostring(msg.subject or '(none)'),
        table.concat(msg.keywords or {}, ','))
    end

    -- Completion handler of the write; it only logs the outcome
    local function on_set(set_err)
      if set_err then
        rspamd_logger.errx(task, 'llm_context: set failed: %s', tostring(set_err))
      else
        lua_util.debugm(N, task, 'context saved to redis: key=%s, ttl=%s, expiring at %s',
          tostring(ident.key), tostring(ttl), os.date('%Y-%m-%d %H:%M:%S', expire_at))
      end
    end
    -- SETEX stores the payload and sets the expiry (whole seconds) in one command
    local ok = lua_redis.redis_make_request(task, redis_params, ident.key, true, on_set, 'SETEX',
      { ident.key, tostring(math.floor(ttl)), payload })
    if not ok then
      rspamd_logger.errx(task, 'llm_context: set request was not scheduled')
    end
  end

  -- Kick off the read; everything else happens in on_get
  local ok = lua_redis.redis_make_request(task, redis_params, ident.key, false, on_get, 'GET', { ident.key })
  if not ok then
    rspamd_logger.errx(task, 'llm_context: initial get request was not scheduled')
  end
end

return M