HEX
Server: Apache/2
System: Linux nexus-01 4.18.0-553.120.1.el8_10.x86_64 #1 SMP Mon Apr 20 18:04:27 EDT 2026 x86_64
User: aglcoke (1118)
PHP: 8.2.31
Disabled: mail,exec,system,passthru,shell_exec,proc_close,proc_open,dl,popen,show_source,posix_kill,posix_mkfifo,posix_getpwuid,posix_setpgid,posix_setsid,posix_setuid,posix_setgid,posix_seteuid,posix_setegid,posix_uname
Upload Files
File: //usr/share/rspamd/lualib/lua_neural_external.lua
--[[
Copyright (c) 2026, Vsevolod Stakhov <vsevolod@rspamd.com>

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
]]--

--[[
External neural model loading and merging.

This module provides functionality to load pretrained neural models
from external sources (HTTP/HTTPS) via the Maps infrastructure
and merge them with locally trained weights.

Model format (msgpack):
{
  magic = "RNM1",           -- Rspamd Neural Model v1
  version = 1,              -- format version
  model_version = 123,      -- model training version (incremented on retrain)
  providers_digest = "...", -- digest of providers config (must match local)
  ann_data = "...",         -- serialized KANN (zstd compressed)
  pca_data = "...",         -- optional PCA (zstd compressed)
  norm_stats = {...},       -- normalization stats
  roc_thresholds = {...},   -- ROC thresholds
  created_at = timestamp,
}

Usage in neural config:
  external_model = {
    url = "https://your-provider.com/models/<digest>";
    sign_key = "your_key";  -- optional signature verification
    merge_alpha = 0.6;      -- 60% external, 40% local
  };
]]--

local lua_redis = require "lua_redis"
local rspamd_kann = require "rspamd_kann"
local rspamd_logger = require "rspamd_logger"
local rspamd_util = require "rspamd_util"
local rspamd_text = require "rspamd_text"
local ucl = require "ucl"

-- Model format constants
-- MODEL_MAGIC is the value parse_model requires in the `magic` field and
-- serialize_model writes into it.
local MODEL_MAGIC = "RNM1"
-- MODEL_FORMAT_VERSION is the value parse_model requires in the `version`
-- field and serialize_model writes into it.
local MODEL_FORMAT_VERSION = 1

-- Module table; every public function below is attached to it and it is
-- returned at the end of the file.
local exports = {}

-- Cache of loaded external models: url -> { model, map, callbacks }
-- Written by register_model_map and read by get_cached_model.
local external_model_cache = {}

--- Parse external model from msgpack data
-- The payload may be zstd compressed or plain msgpack: decompression is tried
-- first and the input is used unchanged when decompression fails.
-- The decoded object must be a table carrying the expected `magic` and
-- `version` fields; anything else is reported as an error, never raised.
-- @param data raw msgpack data (string or rspamd_text, possibly zstd compressed)
-- @return table with model data or nil, error message
function exports.parse_model(data)
  if not data then
    return nil, "no data"
  end

  -- Convert rspamd_text to string if needed
  local data_str
  if type(data) == 'userdata' or (type(data) == 'table' and data.cookie) then
    data_str = tostring(data)
  else
    data_str = data
  end

  -- Anything that is still not a string cannot be a serialized model; reject
  -- it here instead of handing an arbitrary value to the decompressor.
  if type(data_str) ~= 'string' then
    return nil, "invalid data type: " .. type(data_str)
  end

  -- Try zstd decompression first
  local decompressed
  local err, decompressed_data = rspamd_util.zstd_decompress(data_str)
  if not err and decompressed_data then
    decompressed = tostring(decompressed_data)
  else
    -- Assume uncompressed
    decompressed = data_str
  end

  -- Parse msgpack
  local parser = ucl.parser()
  local ok, parse_err = parser:parse_text(decompressed, 'msgpack')
  if not ok then
    return nil, "failed to parse msgpack: " .. tostring(parse_err or "unknown error")
  end

  local model = parser:get_object()

  -- The payload comes from a remote source: a valid msgpack document whose
  -- root is a scalar or nil must not be indexed as if it were a model table.
  if type(model) ~= 'table' then
    return nil, "invalid model: expected a table, got " .. type(model)
  end

  -- Validate model format. tostring() keeps the error message well-formed
  -- even when the remote field holds a table or a boolean.
  if model.magic ~= MODEL_MAGIC then
    return nil, string.format("invalid magic: expected %s, got %s",
      MODEL_MAGIC, tostring(model.magic))
  end

  if model.version ~= MODEL_FORMAT_VERSION then
    return nil, string.format("unsupported model version: %s (expected %s)",
      tostring(model.version), tostring(MODEL_FORMAT_VERSION))
  end

  return model
end

--- Build a KANN object from the `ann_data` field of a parsed model
-- The field is converted to a Lua string when it is a text-like object,
-- zstd-decompressed and then passed to rspamd_kann.load.
-- @param model parsed model table
-- @return kann_t object or nil, error
function exports.load_ann(model)
  local packed = model.ann_data
  if not packed then
    return nil, "no ann_data in model"
  end

  -- Text-like values (userdata, or a table exposing `cookie`) are stringified
  local packed_kind = type(packed)
  if packed_kind == 'userdata' or (packed_kind == 'table' and packed.cookie) then
    packed = tostring(packed)
  end

  -- zstd_decompress reports failure through its first return value
  local zerr, raw = rspamd_util.zstd_decompress(packed)
  if zerr then
    return nil, "failed to decompress ann_data: " .. zerr
  end

  local loaded = rspamd_kann.load(raw)
  if loaded then
    return loaded
  end

  return nil, "failed to load KANN from model"
end

--- Build the PCA tensor from the optional `pca_data` field of a parsed model
-- A missing field is not an error: nil is returned silently. A decompression
-- failure is logged as a warning and also yields nil.
-- @param model parsed model table
-- @return result of rspamd_tensor.load on the decompressed data, or nil
function exports.load_pca(model)
  local packed = model.pca_data
  if not packed then
    return nil
  end

  -- Text-like values (userdata, or a table exposing `cookie`) are stringified
  local packed_kind = type(packed)
  if packed_kind == 'userdata' or (packed_kind == 'table' and packed.cookie) then
    packed = tostring(packed)
  end

  local zerr, raw = rspamd_util.zstd_decompress(packed)
  if zerr then
    rspamd_logger.warnx(rspamd_config, "failed to decompress pca_data: %s", zerr)
    return nil
  end

  -- The tensor module is only required once there is data to load
  return require("rspamd_tensor").load(raw)
end

--- Check if model is compatible with local config
-- A model is compatible when its `providers_digest` equals the local one.
-- Userdata digests (e.g. text objects read back from storage) are converted
-- with tostring() before comparing, because a userdata never compares equal
-- to a Lua string even when the bytes match.
-- @param model parsed model table
-- @param providers_digest local providers digest (may be nil, which never matches)
-- @return boolean, reason
function exports.is_compatible(model, providers_digest)
  -- Normalise text-like userdata to a plain string; leave other types alone
  local function comparable(value)
    if type(value) == 'userdata' then
      return tostring(value)
    end
    return value
  end

  local model_digest = model.providers_digest
  if not model_digest then
    return false, "model has no providers_digest"
  end

  if comparable(model_digest) ~= comparable(providers_digest) then
    -- tostring() keeps the message safe to build when the local digest is nil
    -- or when the model carries a non-string digest (the model is remote data).
    return false, string.format("providers digest mismatch: model=%s, local=%s",
      tostring(model_digest):sub(1, 8), tostring(providers_digest):sub(1, 8))
  end

  return true, "compatible"
end

--- Blend a locally trained ANN into an external one
-- Intended result: w_new = alpha * w_external + (1-alpha) * w_local
-- The external ANN is the receiver of the merge call and is also the object
-- handed back on success; the local ANN is passed with weight (1 - alpha).
-- NOTE(review): this relies on kann:merge_weights updating the receiver in
-- place - confirm against the rspamd_kann implementation.
-- @param external_ann kann_t from external model
-- @param local_ann kann_t from local training
-- @param alpha weight for external (0.0 - 1.0), defaults to 0.5
-- @return external_ann after the merge, or nil, error
function exports.merge_weights(external_ann, local_ann, alpha)
  if not (external_ann and local_ann) then
    return nil, "missing ann"
  end

  local external_share = alpha or 0.5

  -- Architectures must match before any weights are touched
  if not external_ann:is_compatible(local_ann) then
    return nil, "incompatible ANN architectures"
  end

  local merged, err = external_ann:merge_weights(local_ann, 1.0 - external_share)
  if merged then
    return external_ann
  end

  return nil, "merge failed: " .. (err or "unknown")
end

--- Compose the model URL as "<base_url>/<providers_digest>"
-- Any run of trailing slashes on base_url is dropped first so the result
-- contains exactly one separator before the digest.
-- @param base_url base URL
-- @param providers_digest digest of providers config
-- @return full URL
function exports.build_model_url(base_url, providers_digest)
  -- Parentheses keep only the substituted string, discarding gsub's match count
  local trimmed = (base_url:gsub('/+$', ''))
  return string.format("%s/%s", trimmed, providers_digest)
end

--- Register external model as a map
-- This uses the Maps infrastructure for HTTP loading with signature verification.
--
-- One map is created per URL. Every registration for that URL is kept as a
-- listener (callback plus the providers digest it was registered with), so a
-- second rule sharing the URL is notified as well instead of being dropped.
-- The cache entry for the URL is a single table that is updated in place:
-- loading a model never discards the map handle or the listeners, and
-- finishing the registration never discards an already loaded model.
--
-- @param cfg rspamd_config
-- @param rule neural rule configuration
-- @param providers_digest digest of providers config
-- @param on_load_callback function(model_data, err) called when model is loaded/reloaded
-- @return boolean success
function exports.register_model_map(cfg, rule, providers_digest, on_load_callback)
  local ext_cfg = rule.external_model
  if not ext_cfg or not ext_cfg.url then
    return false
  end

  local url = ext_cfg.url
  local cache_key = url

  -- What this registration wants to be told, and which digest it accepts
  local listener = {
    fn = on_load_callback,
    providers_digest = providers_digest,
  }

  -- Already registered: attach to the existing map rather than creating a
  -- second one, skipping exact duplicates so nobody is notified twice.
  local entry = external_model_cache[cache_key]
  if entry then
    entry.callbacks = entry.callbacks or {}
    local known = false
    for _, other in ipairs(entry.callbacks) do
      if other.fn == on_load_callback and other.providers_digest == providers_digest then
        known = true
        break
      end
    end
    if not known then
      entry.callbacks[#entry.callbacks + 1] = listener
    end
    return true
  end

  -- Created before the map so the map callback always has somewhere to store
  -- the model, whenever it happens to fire.
  entry = { callbacks = { listener } }

  -- Copy of the listener list, so a registration made from inside a callback
  -- cannot change the list while it is being walked.
  local function current_listeners()
    local copy = {}
    for i, item in ipairs(entry.callbacks) do
      copy[i] = item
    end
    return copy
  end

  -- Report the same outcome to every listener
  local function notify_all(model, err)
    for _, item in ipairs(current_listeners()) do
      if item.fn then
        item.fn(model, err)
      end
    end
  end

  -- Map callback: called when map data is loaded
  local function map_callback(data, _)
    if not data then
      rspamd_logger.errx(cfg, 'external neural model map returned no data for %s', url)
      notify_all(nil, "no data from map")
      return
    end

    -- Parse model
    local model, parse_err = exports.parse_model(data)
    if not model then
      rspamd_logger.errx(cfg, 'failed to parse external neural model from %s: %s', url, parse_err)
      notify_all(nil, parse_err)
      return
    end

    -- Check compatibility separately for each listener's digest
    local listeners = current_listeners()
    local rejections = {}
    local accepted = false
    for _, item in ipairs(listeners) do
      local compatible, reason = exports.is_compatible(model, item.providers_digest)
      if compatible then
        accepted = true
      else
        rejections[item] = reason or "incompatible"
      end
    end

    -- Update the cache before notifying, so callbacks can read the new model
    if accepted then
      rspamd_logger.infox(cfg, 'loaded external neural model from %s (version=%s)',
        url, model.model_version or 0)

      entry.model = model
      entry.last_version = model.model_version
      entry.last_load = os.time()
    end

    for _, item in ipairs(listeners) do
      local reason = rejections[item]
      if reason then
        rspamd_logger.errx(cfg, 'external neural model incompatible: %s', reason)
        if item.fn then
          item.fn(nil, reason)
        end
      elseif item.fn then
        item.fn(model, nil)
      end
    end
  end

  -- Create callback map
  local map = cfg:add_map({
    url = url,
    type = 'callback',
    description = string.format('External neural model for rule %s', rule.prefix or 'default'),
    callback = map_callback,
    opaque_data = true,  -- Get data as rspamd_text
  })

  if not map then
    rspamd_logger.errx(cfg, 'failed to register external neural model map for %s', url)
    return false
  end

  -- Set sign key if configured
  if ext_cfg.sign_key then
    map:set_sign_key(ext_cfg.sign_key)
  end

  -- Publish the entry only once the map exists; a model stored by an early
  -- callback is kept because the same table is reused.
  entry.map = map
  external_model_cache[cache_key] = entry

  return true
end

--- Look up the most recently cached model for a URL
-- Returns nil both when the URL has no cache entry and when the entry exists
-- but holds no model.
-- @param url model URL
-- @return cached model data or nil
function exports.get_cached_model(url)
  local entry = external_model_cache[url]
  if not entry then
    return nil
  end

  return entry.model or nil
end

--- Derive the effective external model settings for a neural rule
-- Takes `rule.external_model`, resolves the model URL (explicit `url`, or
-- `base_url` combined with the providers digest) and fills in defaults for
-- every option that was left unset.
-- @param rule neural rule configuration
-- @param providers_digest digest of providers config
-- @return table with external model config or nil
function exports.create_external_config(rule, providers_digest)
  local ext = rule.external_model
  if not ext then
    return nil
  end

  -- An explicit url wins; otherwise it is derived from base_url and the digest
  local url = ext.url
  if not url then
    if not ext.base_url then
      rspamd_logger.errx(rspamd_config, 'external_model requires url or base_url')
      return nil
    end
    url = exports.build_model_url(ext.base_url, providers_digest)
  end

  local result = {}
  result.url = url
  result.sign_key = ext.sign_key
  result.merge_strategy = ext.merge_strategy or "interpolate"
  result.merge_alpha = ext.merge_alpha or 0.5
  result.check_interval = ext.check_interval or 86400 -- 24h
  -- Enabled unless explicitly set to false
  result.local_fine_tune = ext.local_fine_tune ~= false
  result.min_local_samples = ext.min_local_samples or 50
  result.providers_digest = providers_digest
  -- Runtime state: nothing loaded yet; last_version and last_check start unset
  result.loaded = false

  return result
end

--- Store external model metadata in Redis for later merge
-- Writes a hash at `<ann_key>_base` with the fields `version`, `ann_data`,
-- `providers_digest` and `created_at`.
-- A model without `ann_data` is rejected up front and reported through the
-- callback: a nil in the middle of the argument list would leave a hole in
-- the array and produce a malformed HMSET.
-- @param redis redis params
-- @param ev_base event base
-- @param ann_key Redis key for the ANN
-- @param model_data parsed model data
-- @param callback function(err)
function exports.store_base_model(redis, ev_base, ann_key, model_data, callback)
  -- Store base model version and compressed ann_data for re-merge
  local base_key = ann_key .. "_base"

  -- Logs a failure and forwards the outcome to the optional caller callback
  local function store_cb(err)
    if err then
      rspamd_logger.errx(rspamd_config, "failed to store base model: %s", err)
    end
    if callback then
      callback(err)
    end
  end

  local ann_data = model_data.ann_data
  if ann_data == nil then
    store_cb("no ann_data in model")
    return
  end

  -- Ensure ann_data is rspamd_text for opaque storage
  if type(ann_data) == 'string' then
    -- Already compressed, convert to text
    ann_data = rspamd_text.fromstring(ann_data)
  end

  -- Store base version and ann_data
  lua_redis.redis_make_request_taskless(ev_base,
    rspamd_config,
    redis,
    nil,
    true, -- is write
    store_cb,
    'HMSET',
    {
      base_key,
      'version', tostring(model_data.model_version or 0),
      'ann_data', ann_data,
      'providers_digest', model_data.providers_digest or '',
      'created_at', tostring(model_data.created_at or os.time()),
    },
    { opaque_data = true }
  )
end

--- Load base model from Redis for re-merge
-- Reads the fields `version`, `ann_data`, `providers_digest` and `created_at`
-- of the hash at `<ann_key>_base`.
--
-- The request is made with `opaque_data`, so reply values are not plain Lua
-- strings. They are converted explicitly here: tonumber() on a userdata is
-- always nil, which previously turned every version and timestamp into 0.
-- A reply without a usable `ann_data` value (HMGET on an absent key still
-- returns a table) is reported as "no base model found".
-- NOTE(review): payload detection relies on text replies supporting the
-- length operator while the Redis null placeholder does not - confirm.
--
-- @param redis redis params
-- @param ev_base event base
-- @param ann_key Redis key for the ANN
-- @param callback function(err, model_data); model_data has model_version
--   (number), ann_data (as returned by Redis), providers_digest (string or
--   nil) and created_at (number)
function exports.load_base_model(redis, ev_base, ann_key, callback)
  local base_key = ann_key .. "_base"

  -- True for a non-empty Lua string or a non-empty text-like userdata
  local function has_payload(value)
    local kind = type(value)
    if kind == 'string' then
      return #value > 0
    end
    if kind ~= 'userdata' then
      return false
    end
    -- pcall: a userdata without a length metamethod raises on `#`
    local ok, len = pcall(function()
      return #value
    end)
    return ok and type(len) == 'number' and len > 0
  end

  -- Plain string for a payload value, nil for anything else
  local function as_string(value)
    if has_payload(value) then
      return tostring(value)
    end
    return nil
  end

  local function load_cb(err, data)
    if err then
      callback(err, nil)
      return
    end

    if type(data) ~= 'table' or not has_payload(data[2]) then
      callback("no base model found", nil)
      return
    end

    local model_data = {
      model_version = tonumber(as_string(data[1]) or '') or 0,
      ann_data = data[2], -- rspamd_text
      providers_digest = as_string(data[3]),
      created_at = tonumber(as_string(data[4]) or '') or 0,
    }

    callback(nil, model_data)
  end

  lua_redis.redis_make_request_taskless(ev_base,
    rspamd_config,
    redis,
    nil,
    false, -- is write
    load_cb,
    'HMGET',
    { base_key, 'version', 'ann_data', 'providers_digest', 'created_at' },
    { opaque_data = true }
  )
end

--- Serialize a model into the msgpack envelope understood by parse_model
-- The ANN (and the PCA, when given) are saved and zstd-compressed
-- individually; the envelope itself is emitted by ucl.to_format without a
-- further compression step here.
-- @param ann kann_t object
-- @param pca optional PCA tensor
-- @param providers_digest providers config digest
-- @param opts optional { model_version, norm_stats, roc_thresholds };
--   model_version defaults to 1
-- @return result of ucl.to_format(model, 'msgpack')
function exports.serialize_model(ann, pca, providers_digest, opts)
  local options = opts or {}

  -- Save ANN to memory and compress it
  local ann_saved = ann:save()
  local ann_packed = rspamd_util.zstd_compress(ann_saved)

  local envelope = {
    magic = MODEL_MAGIC,
    version = MODEL_FORMAT_VERSION,
    model_version = options.model_version or 1,
    providers_digest = providers_digest,
    ann_data = ann_packed,
    created_at = os.time(),
  }

  if pca then
    local pca_saved = pca:save()
    envelope.pca_data = rspamd_util.zstd_compress(pca_saved)
  end

  -- Optional sections are copied only when the caller supplied them
  for _, field in ipairs({ 'norm_stats', 'roc_thresholds' }) do
    local value = options[field]
    if value then
      envelope[field] = value
    end
  end

  return ucl.to_format(envelope, 'msgpack')
end

-- Expose the format constants on the module table so callers can refer to the
-- same values that parse_model checks and serialize_model writes.
exports.MODEL_MAGIC = MODEL_MAGIC
exports.MODEL_FORMAT_VERSION = MODEL_FORMAT_VERSION

-- Module result: the table carrying all public functions and constants
return exports