File: //proc/thread-self/root/proc/self/root/proc/thread-self/root/usr/share/rspamd/plugins/neural.lua
--[[
Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
]] --
local fun = require "fun"
local lua_redis = require "lua_redis"
local lua_util = require "lua_util"
local lua_verdict = require "lua_verdict"
local neural_common = require "plugins/neural"
local neural_learn = require "lua_neural_learn"
local neural_external = require "lua_neural_external"
local rspamd_kann = require "rspamd_kann"
local rspamd_logger = require "rspamd_logger"
local rspamd_tensor = require "rspamd_tensor"
local rspamd_text = require "rspamd_text"
local rspamd_util = require "rspamd_util"
local T = require "lua_shape.core"
local PluginSchema = require "lua_shape.plugin_schema"
-- Load providers
pcall(require, "plugins/neural/providers/llm")
pcall(require, "plugins/neural/providers/symbols")
pcall(require, "plugins/neural/providers/text_hash")
pcall(require, "plugins/neural/providers/fasttext_embed")
-- Tag used for this plugin's debug logging (passed to lua_util.debugm)
local N = "neural"
-- Shared plugin settings owned by plugins/neural; rules are read from settings.rules
local settings = neural_common.settings
-- Shape of an ANN profile entry as stored in Redis: new_ann_profile serializes
-- such a table to compact JSON and ZADDs it to the settings element's prefix key
local redis_profile_schema = T.table({
digest = T.string():doc({ summary = "Symbols digest" }),
symbols = T.array(T.string()):doc({ summary = "List of symbols" }),
version = T.number():doc({ summary = "Profile version" }),
redis_key = T.string():doc({ summary = "Redis key for ANN" }),
distance = T.number():optional():doc({ summary = "Distance metric" }),
providers_digest = T.string():optional():doc({ summary = "Providers digest" }),
}):doc({ summary = "Neural network profile schema" })
-- Registered before the confighelp early return below, presumably so the
-- schema is still available in config-help mode (NOTE(review): confirm)
PluginSchema.register("plugins.neural.profile", redis_profile_schema)
-- In config-help mode nothing past the schema registration is executed
if confighelp then
return
end
-- Result of rspamd_tensor.has_blas(), cached once at load time
-- (presumably reports whether BLAS acceleration is available)
local has_blas = rspamd_tensor.has_blas()
-- Cookie carried by rspamd_text userdata; Redis replies are compared against
-- it (`v.cookie == text_cookie`) to detect text values
local text_cookie = rspamd_text.cookie
-- Forward declarations
-- (assigned further down in the file; new_ann_profile calls
-- maybe_carryover_ann before its definition appears)
local maybe_carryover_ann
local load_ann_profile
--- Create a fresh ANN profile for a settings element and register it in Redis.
-- The profile is serialized to compact JSON and added to the sorted set
-- `set.prefix`, scored by the current time. The table is returned immediately;
-- the Redis write completes asynchronously.
--
-- Once the write succeeds, weights from an earlier compatible profile are
-- carried into the new key, but only when the input vector is defined purely
-- by providers (`rule.disable_symbols_input`). In hybrid providers+symbols
-- mode the symbol part of the vector changes shape with symbol drift.
-- load_new_ann would then pair copied weights with the current symbol list,
-- i.e. with features they were never trained on. For that mode
-- is_profile_compatible already routes inference to the older profile entry,
-- which keeps its own symbol list, so no carryover is done.
-- @param task current task (Redis I/O and logging)
-- @param rule neural rule
-- @param set settings element owning the profile
-- @param version profile version
-- @return the new profile table
local function new_ann_profile(task, rule, set, version)
  local ucl = require "ucl"
  local key = neural_common.new_ann_key(rule, set, version, settings)
  local digest_of_providers = neural_common.providers_config_digest(rule.providers, rule)

  local profile = {
    symbols = set.symbols,
    redis_key = key,
    version = version,
    digest = set.digest,
    distance = 0, -- Since we are using our own profile
    providers_digest = digest_of_providers,
  }
  local serialized = ucl.to_format(profile, 'json-compact', true)

  local function on_profile_stored(err, _)
    if err then
      rspamd_logger.errx(task, 'cannot store ANN profile for %s:%s at %s : %s',
          rule.prefix, set.name, profile.redis_key, err)
      return
    end
    rspamd_logger.infox(task, 'created new ANN profile for %s:%s, data stored at prefix %s',
        rule.prefix, set.name, profile.redis_key)
    -- Carry weights over only when providers alone define the vector layout
    local providers_only = digest_of_providers and rule.disable_symbols_input
    if providers_only then
      maybe_carryover_ann(task, rule, set, key, digest_of_providers)
    end
  end

  lua_redis.redis_make_request(task, rule.redis, nil,
      true, -- is write
      on_profile_stored,
      'ZADD',
      { set.prefix, tostring(rspamd_util.get_time()), serialized })

  return profile
end
--- ANN filter: score a task with each rule's loaded ANN and insert spam/ham symbols.
-- For every neural rule, the settings element applicable to the task is looked
-- up. When it has a loaded ANN, a feature vector is built and fed through the
-- network:
--  * with providers: asynchronously via neural_common.collect_features_async;
--  * without providers: from the task's symbol results.
-- The vector is normalised first when the profile has normalisation stats.
-- A positive output may insert `rule.symbol_spam` and a non-positive one
-- `rule.symbol_ham`, subject to the configured or ROC-derived thresholds.
-- The raw score is cached on the task as `<prefix>_neural_score`.
-- @param task current task
local function ann_scores_filter(task)
  local sid = task:get_settings_id() or -1
  for _, rule in pairs(settings.rules) do
    local set = neural_common.get_rule_settings(task, rule)
    local profile, ann
    if set then
      if set.ann then
        profile = set.ann
        ann = profile.ann
      else
        lua_util.debugm(N, task, 'no ann loaded for %s:%s',
            rule.prefix, set.name)
      end
    else
      lua_util.debugm(N, task, 'no ann defined in %s for settings id %s',
          rule.prefix, sid)
    end
    if ann then
      -- Snapshot everything the scorer needs together with `ann`.
      -- Feature collection may finish asynchronously, and meanwhile set.ann can
      -- be replaced (reload) or mutated in place (external model). Reading
      -- PCA, normalisation stats or thresholds later could pair this network
      -- with data from another one.
      local pca = profile.pca
      local norm_stats = profile.norm_stats
      local roc_thresholds = profile.roc_thresholds
      local version = profile.version
      local stored_digest = profile.providers_digest
      local has_providers = rule.providers and #rule.providers > 0

      local function after_features(vec, meta)
        -- For providers-based ANNs, require matching digest
        -- For symbols-based ANNs (no providers), skip this check
        if has_providers then
          local current_digest = meta and meta.digest
          if not stored_digest then
            -- Old ANN was trained without providers - needs retraining with current config
            lua_util.debugm(N, task,
                'ANN %s:%s was trained without providers, skipping (retrain with current config)',
                rule.prefix, set.name)
            vec = nil
          elseif stored_digest ~= current_digest then
            rspamd_logger.warnx(task,
                'providers config changed for %s:%s (stored=%s, current=%s), ANN needs retraining',
                rule.prefix, set.name, stored_digest, current_digest or 'none')
            vec = nil
          end
        end
        if not vec then
          return
        end
        if norm_stats then
          vec = neural_common.apply_normalization(vec, norm_stats)
        end
        local out = ann:apply1(vec, pca)
        local score = out[1]
        local symscore = string.format('%.3f', score)
        task:cache_set(rule.prefix .. '_neural_score', score)
        lua_util.debugm(N, task, '%s:%s:%s ann score: %s',
            rule.prefix, set.name, version, symscore)
        if score > 0 then
          local result = score
          -- If spam_score_threshold is defined, override all other thresholds.
          local spam_threshold = 0
          if rule.spam_score_threshold then
            spam_threshold = rule.spam_score_threshold
          elseif rule.roc_enabled and roc_thresholds then
            spam_threshold = roc_thresholds[1]
          end
          if result >= spam_threshold then
            if rule.flat_threshold_curve then
              task:insert_result(rule.symbol_spam, 1.0, symscore)
            else
              task:insert_result(rule.symbol_spam, result, symscore)
            end
          else
            lua_util.debugm(N, task, '%s:%s:%s ann score: %s < %s (spam threshold)',
                rule.prefix, set.name, version, symscore,
                spam_threshold)
          end
        else
          local result = -(score)
          -- If ham_score_threshold is defined, override all other thresholds.
          local ham_threshold = 0
          if rule.ham_score_threshold then
            ham_threshold = rule.ham_score_threshold
          elseif rule.roc_enabled and roc_thresholds then
            ham_threshold = roc_thresholds[2]
          end
          if result >= ham_threshold then
            if rule.flat_threshold_curve then
              task:insert_result(rule.symbol_ham, 1.0, symscore)
            else
              task:insert_result(rule.symbol_ham, result, symscore)
            end
          else
            lua_util.debugm(N, task, '%s:%s:%s ann score: %s < %s (ham threshold)',
                rule.prefix, set.name, version, result,
                ham_threshold)
          end
        end
      end

      if has_providers then
        neural_common.collect_features_async(task, rule, profile, 'infer', after_features)
      else
        after_features(neural_common.result_to_vector(task, profile))
      end
    end
  end
end
--- Read the `ANN-Train` request header used for manual training.
-- When the header is returned as a list, only its first value is considered.
-- @param task current task
-- @return the header value converted to a lowercase string, or nil when absent
local function get_ann_train_header(task)
  local value = task:get_request_header('ANN-Train')
  if type(value) == 'table' then
    value = value[1]
  end
  if not value then
    return nil
  end
  return string.lower(tostring(value))
end
-- Classify a rule's feature providers.
-- Returns two booleans:
--  * has_llm: an 'llm' provider is configured;
--  * has_symbols: a 'symbols' provider is configured, or no providers are
--    configured at all (implicit symbols-only mode).
local function ann_provider_kinds(rule)
  if not (rule.providers and #rule.providers > 0) then
    -- No providers configured = implicit symbols-only mode
    return false, true
  end
  local has_llm, has_symbols = false, false
  for _, p in ipairs(rule.providers) do
    if p.type == 'llm' then
      has_llm = true
    elseif p.type == 'symbols' then
      has_symbols = true
    end
  end
  return has_llm, has_symbols
end

--- Decide whether a scanned task becomes a training sample, and store it.
-- The learn class comes from the first of these that applies:
--  1. the `ANN-Train` request header;
--  2. the neural autolearn class stored in the mempool;
--  3. expression-based autolearn (LLM providers);
--  4. the score/verdict thresholds in `rule.train`.
-- With `store_pool_only`, the vector is only cached on the task. Otherwise it
-- is zstd-compressed and SADDed to `<train key>_<spam|ham>_set`, once the
-- vectors_len script confirms it may be pushed (manual training skips that
-- limit).
-- @param rule neural rule
-- @param task current task
-- @param verdict verdict string, used when no score thresholds are configured
-- @param score task score, compared with train.spam_score / train.ham_score
-- @param set settings element to train
local function ann_push_task_result(rule, task, verdict, score, set)
  local train_opts = rule.train
  local has_providers = rule.providers and #rule.providers > 0
  local learn_spam, learn_ham
  local skip_reason = 'unknown'
  local manual_train = false

  -- First, honor explicit manual training header if present
  local hv = get_ann_train_header(task)
  if hv then
    lua_util.debugm(N, task, 'found ANN-Train header, enable manual train mode: %s', hv)
    if hv == 'spam' then
      learn_spam = true
      manual_train = true
    elseif hv == 'ham' then
      learn_ham = true
      manual_train = true
    else
      skip_reason = string.format('unknown ANN-Train header value: %s', hv)
    end
  end

  -- Check for autolearn class set by mempool (integration with external learning decisions)
  if not manual_train then
    local autolearn_class = neural_learn.get_autolearn_class(task)
    if autolearn_class then
      lua_util.debugm(N, task, 'found neural autolearn class in mempool: %s', autolearn_class)
      if autolearn_class == 'spam' then
        learn_spam = true
        manual_train = true
      elseif autolearn_class == 'ham' then
        learn_ham = true
        manual_train = true
      end
    end
  end

  local has_llm_provider, has_symbols_provider = ann_provider_kinds(rule)

  if has_llm_provider and not manual_train then
    -- Use expression-based autolearn conditions for LLM providers
    if rule.autolearn and rule.autolearn.enabled then
      local learn_type, reason = neural_learn.get_learn_type(task, rule)
      if learn_type == 'spam' then
        learn_spam = true
        lua_util.debugm(N, task, 'autolearn spam via expression: %s', reason)
      elseif learn_type == 'ham' then
        learn_ham = true
        lua_util.debugm(N, task, 'autolearn ham via expression: %s', reason)
      else
        skip_reason = reason or 'autolearn condition not met'
        lua_util.debugm(N, task, 'autolearn skip: %s', skip_reason)
      end
    else
      -- LLM provider without autolearn config - require manual training
      learn_spam = false
      learn_ham = false
      skip_reason = 'llm provider requires autolearn config or manual training'
      lua_util.debugm(N, task, 'suppress autotrain: llm provider present, no autolearn config')
    end
  elseif not manual_train and (not train_opts.store_pool_only and train_opts.autotrain) then
    -- Traditional score/verdict based learning for non-LLM providers
    if train_opts.spam_score then
      learn_spam = score >= train_opts.spam_score
      if not learn_spam then
        skip_reason = string.format('score < spam_score: %f < %f',
            score, train_opts.spam_score)
      end
    else
      learn_spam = verdict == 'spam' or verdict == 'junk'
      if not learn_spam then
        skip_reason = string.format('verdict: %s',
            verdict)
      end
    end
    if train_opts.ham_score then
      learn_ham = score <= train_opts.ham_score
      if not learn_ham then
        skip_reason = string.format('score > ham_score: %f > %f',
            score, train_opts.ham_score)
      end
    else
      learn_ham = verdict == 'ham'
      if not learn_ham then
        skip_reason = string.format('verdict: %s',
            verdict)
      end
    end
  elseif not manual_train then
    if train_opts.store_pool_only then
      local ucl = require "ucl"
      learn_ham = false
      learn_spam = false
      -- Explicitly store tokens in cache (use async collector if providers configured)
      local function after_collect(vec)
        if not vec then
          vec = neural_common.result_to_vector(task, set)
        end
        task:cache_set(rule.prefix .. '_neural_vec_mpack', ucl.to_format(vec, 'msgpack'))
        task:cache_set(rule.prefix .. '_neural_profile_digest', set.digest)
      end
      if has_providers then
        neural_common.collect_features_async(task, rule, set, 'train', after_collect)
      else
        after_collect(nil)
      end
      skip_reason = 'store_pool_only has been set'
    end
  end

  if not (learn_spam or learn_ham) then
    lua_util.debugm(N, task,
        'do not push data to key %s: train condition not satisfied; reason: %s',
        (set.training_profile or set.ann or {}).redis_key,
        skip_reason)
    return
  end

  -- For manual training, bypass can_store_vectors check (it may not be set yet)
  if not (set.can_store_vectors or manual_train) then
    lua_util.debugm(N, task,
        'do not push data: train condition not satisfied; reason: not checked existing ANNs')
    return
  end

  local learn_type = learn_spam and 'spam' or 'ham'

  if not set.ann and not set.training_profile then
    -- No ANN and no best-known profile discovered by process_existing_ann
    -- yet - bootstrap a fresh profile for the current configuration.
    set.ann = new_ann_profile(task, rule, set, 0)
    lua_util.debugm(N, task,
        'requested new profile for %s, no ann/training target (manual_train=%s)',
        set.name, manual_train)
  end

  -- Resolve the training target once. The vector counts checked by the
  -- vectors_len script, the key we store into, and the error logs must all
  -- refer to the same profile. Otherwise a concurrent profile switch could
  -- push vectors into a key whose limits were never checked.
  local train_key = (set.training_profile or set.ann).redis_key

  local function store_train_vec(vec)
    if not vec then
      lua_util.debugm(N, task, "no vector collected for training")
      return
    end
    local str = rspamd_util.zstd_compress(table.concat(vec, ';'))
    -- For manual training:
    -- - LLM-only mode: use pending key (embedding dims may vary between versions)
    -- - Symbols-only or hybrid (LLM+symbols): use versioned key (dimension is stable)
    local target_key
    if manual_train and has_llm_provider and not has_symbols_provider then
      target_key = neural_common.pending_train_key(rule, set) .. '_' .. learn_type .. '_set'
    else
      target_key = train_key .. '_' .. learn_type .. '_set'
    end
    local function learn_vec_cb(redis_err)
      if redis_err then
        rspamd_logger.errx(task, 'cannot store train vector for %s:%s: %s',
            rule.prefix, set.name, redis_err)
      else
        lua_util.debugm(N, task,
            "add train data for ANN rule " ..
                "%s:%s, save %s vector of %s elts in %s key; %s bytes compressed",
            rule.prefix, set.name, learn_type, #vec, target_key, #str)
      end
    end
    lua_redis.redis_make_request(task,
        rule.redis,
        nil,
        true, -- is write
        learn_vec_cb, --callback
        'SADD', -- command
        { target_key, str } -- arguments
    )
  end

  local function vectors_len_cb(err, data)
    if err then
      rspamd_logger.errx(task, 'cannot check if we can train %s:%s : %s',
          rule.prefix, set.name, err)
    elseif type(data) == 'table' then
      local nspam, nham = data[1], data[2]
      if manual_train or neural_common.can_push_train_vector(rule, task, learn_type, nspam, nham) then
        if has_providers then
          -- Use async feature collection with providers, same as inference
          neural_common.collect_features_async(task, rule, set, 'train', store_train_vec)
        else
          -- Traditional symbol-based vector
          store_train_vec(neural_common.result_to_vector(task, set))
        end
      else
        lua_util.debugm(N, task,
            "do not add %s train data for ANN rule " ..
                "%s:%s",
            learn_type, rule.prefix, set.name)
      end
    elseif type(data) == 'string' then
      -- nil return value
      rspamd_logger.infox(task, "cannot learn %s ANN %s:%s; redis_key: %s: locked for learning: %s",
          learn_type, rule.prefix, set.name, train_key, data)
    else
      rspamd_logger.errx(task, 'cannot check if we can train %s:%s : type of Redis key %s is %s, expected table; ' ..
          'please remove this key from Redis manually if you perform upgrade from the previous version',
          rule.prefix, set.name, train_key, type(data))
    end
  end

  lua_redis.exec_redis_script(neural_common.redis_script_id.vectors_len,
      { task = task, is_write = false },
      vectors_len_cb,
      { train_key })
end
--- Offline training logic
-- Decode training vectors fetched from Redis.
-- Each element of `data` is expected to be a zstd-compressed string of
-- ';'-separated numbers, as written by ann_push_task_result.
-- Elements that fail to decompress are skipped. Previously they were turned
-- into the string "nil", giving empty vectors that poisoned the training set.
-- Tokens that are not numbers are dropped, as before.
-- @param data array of compressed vectors (Redis set members)
-- @return array of numeric arrays
local function process_training_vectors(data)
  local vectors = {}
  for _, compressed in ipairs(data) do
    local err, str = rspamd_util.zstd_decompress(compressed)
    if err or not str then
      lua_util.debugm(N, rspamd_config,
          'skip training vector that cannot be decompressed: %s', err)
    else
      local vec = {}
      for _, tok in ipairs(lua_util.str_split(tostring(str), ';')) do
        local num = tonumber(tok)
        if num then
          vec[#vec + 1] = num
        end
      end
      vectors[#vectors + 1] = vec
    end
  end
  return vectors
end
--- Lock an ANN in Redis, fetch its training vectors and spawn training.
-- Steps:
--  * return immediately if this worker is already training for `set`;
--  * run the maybe_lock script on `ann_key`;
--  * once locked, SUNION the spam vectors from the versioned key and the
--    pending key, then the ham vectors likewise;
--  * pass both sets to neural_common.spawn_train.
-- If a vector fetch fails, the lock is released again.
-- @param worker current worker
-- @param ev_base event base for taskless Redis requests
-- @param rule neural rule
-- @param set settings element being trained
-- @param ann_key Redis key of the ANN profile to train
local function do_train_ann(worker, ev_base, rule, set, ann_key)
  -- Check early to prevent concurrent training
  if set.learning_spawned then
    lua_util.debugm(N, rspamd_config, 'do_train_ann: training already in progress for %s:%s, skipping',
        rule.prefix, set.name)
    return
  end
  local spam_elts = {}
  local pending_key = neural_common.pending_train_key(rule, set)
  lua_util.debugm(N, rspamd_config, 'do_train_ann: start for %s:%s key=%s pending=%s',
      rule.prefix, set.name, ann_key, pending_key)

  -- Release the lock taken by the maybe_lock script (used on fetch errors)
  local function unlock_ann()
    lua_redis.redis_make_request_taskless(ev_base,
        rspamd_config,
        rule.redis,
        nil,
        true, -- is write
        neural_common.gen_unlock_cb(rule, set, ann_key), --callback
        'HDEL', -- command
        { ann_key, 'lock' }
    )
  end

  -- Ham vectors received: spam vectors are already decoded, start training
  local function redis_ham_cb(err, data)
    if err or type(data) ~= 'table' then
      rspamd_logger.errx(rspamd_config, 'cannot get ham tokens for ANN %s from redis: %s',
          ann_key, err)
      unlock_ann()
      return
    end
    neural_common.spawn_train({
      worker = worker,
      ev_base = ev_base,
      rule = rule,
      set = set,
      ann_key = ann_key,
      ham_vec = process_training_vectors(data),
      spam_vec = spam_elts,
      pending_key = pending_key
    })
  end

  -- Spam vectors received: decode them and fetch ham vectors
  local function redis_spam_cb(err, data)
    if err or type(data) ~= 'table' then
      rspamd_logger.errx(rspamd_config, 'cannot get spam tokens for ANN %s from redis: %s',
          ann_key, err)
      unlock_ann()
      return
    end
    spam_elts = process_training_vectors(data)
    -- Now get ham vectors from both versioned and pending keys
    lua_redis.redis_make_request_taskless(ev_base,
        rspamd_config,
        rule.redis,
        nil,
        false, -- is write
        redis_ham_cb, --callback
        'SUNION', -- command (union of sets)
        { ann_key .. '_ham_set', pending_key .. '_ham_set' }
    )
  end

  local function redis_lock_cb(err, data)
    if err then
      rspamd_logger.errx(rspamd_config, 'cannot call lock script for ANN %s from redis: %s',
          ann_key, err)
    elseif type(data) == 'number' and data == 1 then
      -- ANN is locked, so we can extract SPAM and HAM vectors and spawn learning
      -- Fetch from both versioned key and pending key using SUNION
      lua_redis.redis_make_request_taskless(ev_base,
          rspamd_config,
          rule.redis,
          nil,
          false, -- is write
          redis_spam_cb, --callback
          'SUNION', -- command (union of sets)
          { ann_key .. '_spam_set', pending_key .. '_spam_set' }
      )
      rspamd_logger.infox(rspamd_config, 'lock ANN %s:%s (key name %s, pending %s) for learning',
          rule.prefix, set.name, ann_key, pending_key)
    elseif type(data) == 'table' then
      -- Lock is held elsewhere: reply is {lock_time, hostname}
      local lock_tm = tonumber(data[1])
      rspamd_logger.infox(rspamd_config, 'do not learn ANN %s:%s (key name %s), ' ..
          'locked by another host %s at %s', rule.prefix, set.name, ann_key,
          data[2], os.date('%c', lock_tm))
    else
      -- Indexing a non-table reply used to raise inside this callback
      rspamd_logger.warnx(rspamd_config, 'unexpected reply from lock script for ANN %s:%s (key name %s): %s',
          rule.prefix, set.name, ann_key, data)
    end
  end

  -- Call Redis script that tries to acquire a lock
  -- This script returns either a boolean or a pair {'lock_time', 'hostname'} when
  -- ANN is locked by another host (or a process, meh)
  lua_redis.exec_redis_script(neural_common.redis_script_id.maybe_lock,
      { ev_base = ev_base, is_write = true },
      redis_lock_cb,
      {
        ann_key,
        tostring(os.time()),
        tostring(math.floor(math.max(10.0, rule.watch_interval * 2))),
        rspamd_util.get_hostname()
      })
end
-- Parse a UCL/JSON text value (such as an rspamd_text reply from Redis).
-- Returns the parsed object, or nil and the parser error.
local function parse_ucl_text(text)
  local ucl = require "ucl"
  local parser = ucl.parser()
  local ok, parse_err = parser:parse_text(text)
  if not ok then
    return nil, parse_err
  end
  return parser:get_object()
end

--- Load the ANN described by `profile` from Redis into `set.ann`.
-- Fields fetched with HMGET from `profile.redis_key`:
--  * `ann`: zstd-compressed KANN network (required);
--  * `roc_thresholds`: UCL, used only when rule.roc_enabled;
--  * `pca`: zstd-compressed tensor, required when rule.max_inputs is set;
--  * `providers_meta`, `norm_stats`: UCL, optional.
-- On success set.ann is replaced by a fresh table built from `profile`. The
-- profile is then re-ZADDed with the current time as its rank, so it is not
-- expired.
-- If the network is missing or cannot be loaded, set.ann is left untouched
-- and none of the auxiliary fields are applied: they belong to the new
-- profile, not to the currently loaded network.
-- @param rule neural rule
-- @param ev_base event base for taskless Redis requests
-- @param set settings element to update
-- @param profile profile entry to load
-- @param min_diff distance of `profile`, stored as set.ann.distance
local function load_new_ann(rule, ev_base, set, profile, min_diff)
  local ann_key = profile.redis_key

  -- True when `v` is an rspamd_text userdata
  local function is_text(v)
    return type(v) == 'userdata' and v.cookie == text_cookie
  end

  local function data_cb(err, data)
    if err then
      rspamd_logger.errx(rspamd_config, 'cannot get ANN data from key: %s; %s',
          ann_key, err)
      return
    end
    if type(data) ~= 'table' then
      lua_util.debugm(N, rspamd_config, 'no ANN key for %s:%s in Redis key %s',
          rule.prefix, set.name, ann_key)
      return
    end
    if not is_text(data[1]) then
      lua_util.debugm(N, rspamd_config, 'missing ANN for %s:%s in Redis key %s',
          rule.prefix, set.name, ann_key)
      return
    end

    local _err, ann_data = rspamd_util.zstd_decompress(data[1])
    if _err or not ann_data then
      rspamd_logger.errx(rspamd_config, 'cannot decompress ANN for %s from Redis key %s: %s',
          rule.prefix .. ':' .. set.name, ann_key, _err)
      return
    end
    local ann = rspamd_kann.load(ann_data)
    if not ann then
      rspamd_logger.errx(rspamd_config,
          'cannot unpack/deserialise ANN for %s:%s from Redis key %s',
          rule.prefix, set.name, ann_key)
      return
    end

    local loaded = {
      digest = profile.digest,
      version = profile.version,
      symbols = profile.symbols,
      distance = min_diff,
      redis_key = profile.redis_key,
      providers_digest = profile.providers_digest,
    }
    loaded.ann = ann -- To avoid serialization
    set.ann = loaded

    -- Also update rank for the loaded ANN to avoid removal
    local ucl = require "ucl"
    local profile_serialized = ucl.to_format(profile, 'json-compact', true)
    local function rank_cb(_, _)
      -- TODO: maybe add some logging
    end
    lua_redis.redis_make_request_taskless(ev_base,
        rspamd_config,
        rule.redis,
        nil,
        true, -- is write
        rank_cb, --callback
        'ZADD', -- command
        { set.prefix, tostring(rspamd_util.get_time()), profile_serialized }
    )
    rspamd_logger.infox(rspamd_config,
        'loaded ANN for %s:%s from %s; %s bytes compressed; version=%s',
        rule.prefix, set.name, ann_key, #data[1], profile.version)

    -- ROC thresholds (only meaningful when ROC is enabled)
    if rule.roc_enabled and is_text(data[2]) then
      local roc_thresholds, parse_err = parse_ucl_text(data[2])
      if roc_thresholds then
        loaded.roc_thresholds = roc_thresholds
        rspamd_logger.infox(rspamd_config,
            'loaded ROC thresholds for %s:%s; version=%s',
            rule.prefix, set.name, profile.version)
        rspamd_logger.debugx(rspamd_config, "ROC thresholds: %s", roc_thresholds)
      else
        -- Used to be an assert, which threw inside this Redis callback
        rspamd_logger.errx(rspamd_config,
            'cannot parse ROC thresholds for %s:%s from Redis key %s: %s',
            rule.prefix, set.name, ann_key, parse_err)
      end
    end

    -- PCA table
    if is_text(data[3]) then
      local pca_err, pca_data = rspamd_util.zstd_decompress(data[3])
      if pca_data then
        if rule.max_inputs then
          -- We can use PCA
          loaded.pca = rspamd_tensor.load(pca_data)
          rspamd_logger.infox(rspamd_config,
              'loaded PCA for ANN for %s:%s from %s; %s bytes compressed; version=%s',
              rule.prefix, set.name, ann_key, #data[3], profile.version)
        else
          -- no need in pca, why is it there?
          rspamd_logger.warnx(rspamd_config,
              'extra PCA for ANN for %s:%s from Redis key %s: no max inputs defined',
              rule.prefix, set.name, ann_key)
        end
      elseif rule.max_inputs then
        -- The network expects PCA-reduced input; without it it is unusable
        rspamd_logger.errx(rspamd_config, 'cannot unpack/deserialise ANN for %s:%s from Redis key %s: no PCA: %s',
            rule.prefix, set.name, ann_key, pca_err)
        loaded.ann = nil
      else
        -- pca can be missing merely if we have no max_inputs
        loaded.pca = nil
      end
    end

    -- Providers meta (optional)
    if loaded.ann and is_text(data[4]) then
      local meta = parse_ucl_text(data[4])
      if meta then
        loaded.providers_meta = meta
      end
    end

    -- Normalization stats (optional)
    if loaded.ann and is_text(data[5]) then
      local stats = parse_ucl_text(data[5])
      if stats then
        loaded.norm_stats = stats
      end
    end
  end

  lua_redis.redis_make_request_taskless(ev_base,
      rspamd_config,
      rule.redis,
      nil,
      false, -- is write
      data_cb, --callback
      'HMGET', -- command
      { ann_key, 'ann', 'roc_thresholds', 'pca', 'providers_meta', 'norm_stats' }, -- arguments
      { opaque_data = true }
  )
end
--- External model support functions
-- Apply a loaded external model to a settings element.
-- How the external network is combined with set.ann:
--  * compatible local network: weights are blended; `merge_alpha` (default
--    0.5) is the external model's share;
--  * incompatible local network: replaced outright, and PCA, normalisation
--    stats and ROC thresholds derived for the local network are dropped;
--  * no local network: the external one becomes the initial ANN.
-- In every successful case, the PCA, normalisation stats and ROC thresholds
-- shipped with the model are attached. The rule's external_model state
-- (last_version, loaded) is then updated, so the same version is not
-- re-applied on the next check.
-- @param rule neural rule configuration
-- @param set settings element
-- @param model parsed external model data
-- @param ev_base event base (optional, for storing base model)
-- @return true when the model was applied, false otherwise
local function apply_external_model(rule, set, model, ev_base)
  local ext_cfg = rule.external_model
  if not ext_cfg or not model then
    return false
  end
  -- Load external ANN
  local ext_ann, ann_err = neural_external.load_ann(model)
  if not ext_ann then
    rspamd_logger.errx(rspamd_config, 'failed to load external ANN for %s:%s: %s',
        rule.prefix, set.name, ann_err or "unknown")
    return false
  end

  -- Store base model for future re-merge (only when an event base is given)
  local function store_base_model(redis_key)
    if not ev_base then
      return
    end
    neural_external.store_base_model(rule.redis, ev_base, redis_key, model, function(store_err)
      if store_err then
        rspamd_logger.warnx(rspamd_config, 'failed to store base model: %s', store_err)
      end
    end)
  end

  if set.ann and set.ann.ann then
    if not ext_ann:is_compatible(set.ann.ann) then
      rspamd_logger.warnx(rspamd_config,
          'external ANN architecture incompatible with local ANN for %s:%s, using external only',
          rule.prefix, set.name)
      set.ann.ann = ext_ann
      set.ann.version = model.model_version or 1
      set.ann.external_version = model.model_version
      set.ann.external_source = ext_cfg.url
      -- These were computed for the replaced local network and do not match
      -- the external one; the model's own values (if any) are attached below.
      set.ann.pca = nil
      set.ann.norm_stats = nil
      set.ann.roc_thresholds = nil
    else
      -- Merge weights (modifies ext_ann in place, returns boolean)
      -- C merge: w_dst = (1-a)*w_dst + a*w_src, so to get alpha*ext + (1-alpha)*local
      -- we pass (1 - alpha) as the C alpha parameter
      local alpha = ext_cfg.merge_alpha or 0.5
      local merge_ok, merge_err = ext_ann:merge_weights(set.ann.ann, 1.0 - alpha)
      if not merge_ok then
        rspamd_logger.errx(rspamd_config, 'failed to merge ANNs for %s:%s: %s',
            rule.prefix, set.name, merge_err or "unknown")
        return false
      end
      rspamd_logger.infox(rspamd_config,
          'merged external model (version=%s, alpha=%s) with local ANN for %s:%s',
          model.model_version, alpha, rule.prefix, set.name)
      -- Update ANN reference (merge_weights modifies ext_ann in place)
      set.ann.ann = ext_ann
      set.ann.version = (set.ann.version or 0) + 1
      set.ann.external_version = model.model_version
      set.ann.external_source = ext_cfg.url
      store_base_model(set.ann.redis_key)
    end
  else
    -- No local ANN, just use external
    rspamd_logger.infox(rspamd_config,
        'loaded external model (version=%s) as initial ANN for %s:%s',
        model.model_version, rule.prefix, set.name)
    set.ann = {
      version = model.model_version or 1,
      -- Same argument list as new_ann_profile uses for new_ann_key
      redis_key = neural_common.new_ann_key(rule, set, model.model_version or 1, settings),
      external_version = model.model_version,
      external_source = ext_cfg.url,
      ann = ext_ann,
      providers_digest = ext_cfg.providers_digest,
      digest = 'external:' .. (model.model_version or '0'),
      symbols = set.symbols,
      distance = 0,
    }
    store_base_model(set.ann.redis_key)
  end

  -- Load PCA if present
  local pca = neural_external.load_pca(model)
  if pca then
    set.ann.pca = pca
  end
  -- Copy normalization stats
  if model.norm_stats then
    set.ann.norm_stats = model.norm_stats
  end
  -- Copy ROC thresholds
  if model.roc_thresholds then
    set.ann.roc_thresholds = model.roc_thresholds
  end
  -- Update external model state
  ext_cfg.last_version = model.model_version
  ext_cfg.loaded = true
  return true
end
--- Register the external-model map for a rule (call at config time).
-- Each time the map delivers a model, it is applied to every table-valued
-- entry of `rule.settings`. A load error is logged and the model is ignored.
-- @param rule neural rule configuration
-- @return false when the rule has no external model URL; otherwise whatever
--   neural_external.register_model_map returns
local function register_external_model_map(rule)
  local ext_cfg = rule.external_model
  if not (ext_cfg and ext_cfg.url) then
    return false
  end

  return neural_external.register_model_map(rspamd_config, rule, ext_cfg.providers_digest,
      function(model, err)
        if err then
          rspamd_logger.errx(rspamd_config, 'external model load failed for %s: %s',
              rule.prefix, err)
          return
        end
        -- Apply model to all settings
        for _, set in pairs(rule.settings) do
          if type(set) == 'table' then
            apply_external_model(rule, set, model, nil)
          end
        end
      end)
end
--- Periodic check for a newer cached external model.
-- Most reloads are handled by the map infrastructure itself. This only
-- re-applies the model cached for `external_model.url` when its version
-- differs from the last applied one.
-- @param worker unused
-- @param cfg config object used for logging
-- @param ev_base event base, passed on so the base model can be stored
-- @param rule neural rule configuration
local function check_external_model(worker, cfg, ev_base, rule)
  local ext_cfg = rule.external_model
  if not ext_cfg then
    return
  end

  local cached = neural_external.get_cached_model(ext_cfg.url)
  if not cached or cached.model_version == ext_cfg.last_version then
    return
  end

  rspamd_logger.infox(cfg, 'external model updated for %s: version %s -> %s',
      rule.prefix, ext_cfg.last_version or 0, cached.model_version)
  for _, set in pairs(rule.settings) do
    if type(set) == 'table' then
      apply_external_model(rule, set, cached, ev_base)
    end
  end
end
--- check_anns callback: select the best stored profile for a set and reload its ANN if warranted.
-- Among the profiles accepted by neural_common.is_profile_compatible, the one
-- with the smallest distance wins; ties go to the higher version. The winner
-- is always recorded as set.training_profile, so training vectors flow to it
-- even before its network has been trained.
-- The network itself is (re)loaded when:
--  * nothing is loaded yet;
--  * the digest is the same and the remote version is newer;
--  * the providers digest matches on both sides and the remote version is
--    newer (symbol drift is then irrelevant to the vector shape);
--  * the remote profile is strictly closer than the loaded one.
-- @param _ unused
-- @param ev_base event base for Redis requests
-- @param rule neural rule
-- @param set settings element
-- @param profiles profile entries read from Redis
local function process_existing_ann(_, ev_base, rule, set, profiles)
  local use_providers = rule.providers and #rule.providers > 0
  local wanted_digest = use_providers and
      neural_common.providers_config_digest(rule.providers, rule) or nil

  -- "<prefix>:<setting name>" label used in reload log messages
  local function label()
    return rule.prefix .. ':' .. set.name
  end

  lua_util.debugm(N, rspamd_config,
      'process_existing_ann: have %s profiles for %s:%s (providers_digest=%s)',
      type(profiles) == 'table' and #profiles or -1, rule.prefix, set.name,
      wanted_digest or 'none')

  local best, best_dist = nil, math.huge
  for _, candidate in fun.iter(profiles) do
    local ok, dist = neural_common.is_profile_compatible(
        rule, set, candidate, wanted_digest)
    if ok then
      local closer = dist < best_dist
      local newer_tie = dist == best_dist and best ~= nil
          and (candidate.version or 0) > (best.version or 0)
      if closer or newer_tie then
        best, best_dist = candidate, dist
      end
    end
  end

  if not best then
    lua_util.debugm(N, rspamd_config, 'process_existing_ann: no suitable profile found')
    return
  end

  set.training_profile = {
    redis_key = best.redis_key,
    version = best.version,
    digest = best.digest,
    symbols = best.symbols,
    distance = best_dist,
    providers_digest = best.providers_digest,
  }

  local current = set.ann
  if not current then
    -- Nothing loaded yet
    load_new_ann(rule, ev_base, set, best, best_dist)
  elseif current.digest == best.digest then
    -- Same profile lineage: only a newer version is worth reloading
    if (current.version or 0) < (best.version or 0) then
      rspamd_logger.infox(rspamd_config,
          'ann %s is changed, our version = %s, remote version = %s',
          label(), current.version, best.version)
      load_new_ann(rule, ev_base, set, best, best_dist)
    else
      lua_util.debugm(N, rspamd_config,
          'ann %s is not changed, our version = %s, remote version = %s',
          label(), current.version, best.version)
    end
  elseif use_providers and wanted_digest
      and current.providers_digest == wanted_digest
      and best.providers_digest == wanted_digest then
    -- Providers define the vector shape, so symbol drift does not matter
    if (best.version or 0) > (current.version or 0) then
      rspamd_logger.infox(rspamd_config,
          'providers schema matches for %s; reload newer version %s (ours = %s)',
          label(), best.version, current.version)
      load_new_ann(rule, ev_base, set, best, best_dist)
    else
      lua_util.debugm(N, rspamd_config,
          'providers schema matches for %s; our version %s >= remote %s, no reload',
          label(), current.version, best.version)
    end
  elseif (current.distance or math.huge) > best_dist then
    -- Different profile: switch only to a strictly closer one
    rspamd_logger.infox(rspamd_config,
        'more specific ann is available for %s, our distance = %s, remote distance = %s',
        label(), current.distance, best_dist)
    load_new_ann(rule, ev_base, set, best, best_dist)
  else
    lua_util.debugm(N, rspamd_config,
        'ann %s is not changed or less specific, our distance = %s, remote distance = %s',
        label(), current.distance, best_dist)
  end

  lua_util.debugm(N, rspamd_config, 'process_existing_ann: selected profile version=%s key=%s', best.version,
      best.redis_key)
end
--- Training-side callback for `check_anns`.
-- Picks the first profile that is strictly compatible with the current rule
-- (is_profile_compatible reports it compatible with distance 0). If one is
-- found and store_set_only is off, it counts the training vectors stored in
-- Redis (versioned set of the profile's redis_key plus the pending set, for
-- spam and then ham). When the counts satisfy the configured learn mode,
-- training is started via do_train_ann.
-- @param worker rspamd worker object
-- @param ev_base event base used for taskless Redis requests
-- @param rule neural rule configuration
-- @param set per-settings state of the rule
-- @param profiles list of profiles loaded for this set
local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles)
  local providers_on = rule.providers and #rule.providers > 0
  local wanted_digest = providers_on and
      neural_common.providers_config_digest(rule.providers, rule) or nil
  local counts = {
    spam = 0,
    ham = 0,
  }

  lua_util.debugm(N, rspamd_config, 'maybe_train_existing_ann: %s profiles for %s:%s',
      type(profiles) == 'table' and #profiles or -1, rule.prefix, set.name)

  -- Only a strictly compatible profile (distance 0) may be trained: the
  -- stored vectors must share the current vector schema.
  local matched
  for _, candidate in fun.iter(profiles) do
    local ok, distance = neural_common.is_profile_compatible(
        rule, set, candidate, wanted_digest)
    if ok and distance == 0 then
      matched = candidate
      break
    end
  end

  if not matched then
    return
  end

  local ann_key = matched.redis_key
  if rule.train.store_set_only then
    lua_util.debugm(N, rspamd_config, "skiped check if ANN %s needs to be trained due to store_set_only", ann_key)
    return
  end

  local pending_key = neural_common.pending_train_key(rule, set)
  lua_util.debugm(N, rspamd_config, "check if ANN %s (pending %s) needs to be trained",
      ann_key, pending_key)

  -- Runs once every SCARD lookup below has answered.
  local function decide()
    local values = lua_util.values(counts)
    local largest = math.max(lua_util.unpack(values))
    local smallest = math.min(lua_util.unpack(values))
    local need = rule.train.max_trains
    lua_util.debugm(N, rspamd_config,
        'final vector count for ANN %s: spam=%s ham=%s (min=%s max=%s required=%s)',
        ann_key, counts.spam, counts.ham, smallest, largest, need)

    local ready
    if rule.train.learn_mode == 'balanced' then
      -- Every class must reach max_trains scaled down by classes_bias.
      ready = largest >= need and fun.all(function(_, l)
        return l >= need * (1.0 - rule.train.classes_bias)
      end, counts)
    else
      -- Probabilistic mode: both classes present, one of them large enough.
      ready = smallest > 0 and largest >= need
    end

    if ready then
      rspamd_logger.infox(rspamd_config,
          'need to learn ANN %s (pending %s) after %s required learn vectors',
          ann_key, pending_key, counts)
      lua_util.debugm(N, rspamd_config, 'maybe_train_existing_ann: initiating train for key=%s spam=%s ham=%s', ann_key,
          counts.spam or -1, counts.ham or -1)
      do_train_ann(worker, ev_base, rule, set, ann_key)
    elseif rule.train.learn_mode == 'balanced' then
      lua_util.debugm(N, rspamd_config,
          'cannot learn ANN %s: balanced mode requires more vectors (has %s)',
          ann_key, counts)
    else
      lua_util.debugm(N, rspamd_config,
          'cannot learn ANN %s: need min_len > 0 and max_len >= %s (has %s)',
          ann_key, need, counts)
    end
  end

  -- SCARD lookups in the order they are issued. Each successful numeric
  -- reply is added to counts[class], then the next lookup is started.
  local lookups = {
    { pending = false, class = 'spam', suffix = '_spam_set',
      msg = 'added %s versioned %s vectors, total now %s' },
    { pending = true, class = 'spam', suffix = '_spam_set',
      msg = 'added %s pending %s vectors, total now %s' },
    { pending = false, class = 'ham', suffix = '_ham_set',
      msg = 'added %s versioned %s vectors, total now %s' },
    { pending = true, class = 'ham', suffix = '_ham_set',
      msg = 'added %s pending %s vectors, total now %s' },
  }

  local function run_lookup(idx)
    local step = lookups[idx]
    if step == nil then
      decide()
      return
    end
    -- The key is built only when the step runs.
    local base
    if step.pending then
      base = pending_key
    else
      base = ann_key
    end
    lua_redis.redis_make_request_taskless(ev_base,
        rspamd_config,
        rule.redis,
        nil,
        false,
        function(err, data)
          if not err and (type(data) == 'number' or type(data) == 'string') then
            local added = tonumber(data) or 0
            counts[step.class] = (counts[step.class] or 0) + added
            lua_util.debugm(N, rspamd_config, step.msg,
                added, step.class, counts[step.class])
          end
          run_lookup(idx + 1)
        end,
        'SCARD',
        { base .. step.suffix }
    )
  end

  run_lookup(1)
end
--- Deserialise one ANN profile element read from Redis.
-- The element is parsed as UCL, then validated and transformed with
-- redis_profile_schema.
-- @param element serialized profile string
-- @return the validated profile table, or nil when parsing or schema
--   validation fails (the failure is logged)
load_ann_profile = function(element)
  local ucl = require "ucl"
  local parser = ucl.parser()

  local parsed, parse_err = parser:parse_string(element)
  if not parsed then
    rspamd_logger.warnx(rspamd_config, 'cannot parse ANN from redis: %s',
        parse_err)
    return nil
  end

  local validated, schema_err = redis_profile_schema:transform(parser:get_object())
  if not validated then
    rspamd_logger.errx(rspamd_config, "cannot parse profile schema: %s", schema_err)
    return nil
  end

  return validated
end
--- Async carryover of trained weights into a freshly created profile.
-- Reads the settings zset (ZREVRANGE, most recent first) and collects every
-- other profile whose providers_digest equals target_providers_digest. It
-- then copies the ann/roc_thresholds/pca/providers_meta/norm_stats hash
-- fields from the first candidate that actually holds an ANN blob into
-- new_key. Candidates without an ANN blob are skipped in favour of older
-- ones.
-- The copy is only attempted when new_key has no 'ann' field at the time
-- of the HEXISTS check.
-- @param task rspamd task used for Redis requests and logging
-- @param rule neural rule configuration
-- @param set per-settings state (set.prefix names the profiles zset)
-- @param new_key redis key of the newly created profile
-- @param target_providers_digest providers digest the source must match;
--   when nil nothing is copied
maybe_carryover_ann = function(task, rule, set, new_key, target_providers_digest)
  -- Without a providers digest there is no schema identity to match on,
  -- and `nil == nil` would pair this profile with unrelated legacy ones.
  if not target_providers_digest then
    lua_util.debugm(N, task, 'carryover: no providers digest for %s:%s, skipping',
        rule.prefix, set.name)
    return
  end

  local carried_fields = { 'ann', 'roc_thresholds', 'pca', 'providers_meta', 'norm_stats' }

  -- HMGET is issued with opaque_data, so stored fields arrive as rspamd_text.
  local function is_blob(v)
    return type(v) == 'userdata' and v.cookie == text_cookie
  end

  -- Try candidates[idx]; on a missing ANN blob fall through to the next one.
  local function copy_from(candidates, idx)
    local source_key = candidates[idx]
    if not source_key then
      lua_util.debugm(N, task,
          'carryover: no candidate profile with a trained ANN for %s:%s',
          rule.prefix, set.name)
      return
    end

    local function hmset_cb(hmset_err)
      if hmset_err then
        rspamd_logger.errx(task,
            'carryover: cannot copy ANN from %s to %s: %s',
            source_key, new_key, hmset_err)
      else
        rspamd_logger.infox(task,
            'carryover: copied ANN weights from %s into fresh profile %s ' ..
            '(providers_digest unchanged)',
            source_key, new_key)
      end
    end

    local function hmget_cb(hmget_err, hmget_data)
      if hmget_err or type(hmget_data) ~= 'table' then
        lua_util.debugm(N, task,
            'carryover: HMGET error for %s: %s', source_key, hmget_err)
        return
      end
      if not is_blob(hmget_data[1]) then
        lua_util.debugm(N, task,
            'carryover: source key %s has no ANN blob', source_key)
        -- A newer, never-trained profile may shadow an older trained one.
        copy_from(candidates, idx + 1)
        return
      end
      local args = { new_key }
      for i, fname in ipairs(carried_fields) do
        local v = hmget_data[i]
        if is_blob(v) then
          args[#args + 1] = fname
          args[#args + 1] = v
        end
      end
      -- 'ann' is known to be present here, so there is always a field to set.
      lua_redis.redis_make_request(task,
          rule.redis,
          nil,
          true,
          hmset_cb,
          'HMSET',
          args)
    end

    local hmget_args = { source_key }
    for _, fname in ipairs(carried_fields) do
      hmget_args[#hmget_args + 1] = fname
    end
    lua_redis.redis_make_request(task,
        rule.redis,
        nil,
        false,
        hmget_cb,
        'HMGET',
        hmget_args,
        { opaque_data = true })
  end

  local function zrange_cb(err, data)
    if err or type(data) ~= 'table' then
      lua_util.debugm(N, task, 'carryover: cannot read zset %s: %s',
          set.prefix, err)
      return
    end
    -- Keep ZREVRANGE order (most recently used first).
    local candidates = {}
    for _, raw in ipairs(data) do
      local profile = load_ann_profile(raw)
      if profile
          and profile.providers_digest == target_providers_digest
          and profile.redis_key ~= new_key then
        candidates[#candidates + 1] = profile.redis_key
      end
    end
    if #candidates == 0 then
      lua_util.debugm(N, task,
          'carryover: no prior profile with matching providers_digest for %s:%s',
          rule.prefix, set.name)
      return
    end

    local function exists_cb(hex_err, hex_data)
      if hex_err then
        lua_util.debugm(N, task,
            'carryover: HEXISTS error for %s: %s', new_key, hex_err)
        return
      end
      if tonumber(hex_data) == 1 then
        lua_util.debugm(N, task,
            'carryover: %s already has an ANN, skipping copy', new_key)
        return
      end
      copy_from(candidates, 1)
    end

    lua_redis.redis_make_request(task,
        rule.redis,
        nil,
        false,
        exists_cb,
        'HEXISTS',
        { new_key, 'ann' })
  end

  lua_redis.redis_make_request(task,
      rule.redis,
      nil,
      false,
      zrange_cb,
      'ZREVRANGE',
      { set.prefix, '0', tostring(settings.max_profiles) })
end
--- Read the recently used profiles of every settings set of `rule` and hand
-- them to `process_callback`.
-- For each table-valued entry of rule.settings, the set's zset (set.prefix)
-- is read with ZREVRANGE 0..max_profiles. Each element is mapped through
-- load_ann_profile, and the result is passed as
-- process_callback(worker, ev_base, rule, set, profiles). After the reply
-- has been handled, set.can_store_vectors is set to true.
-- @param worker rspamd worker object
-- @param cfg rspamd config (used for Redis and logging)
-- @param ev_base event base for taskless Redis requests
-- @param rule neural rule configuration
-- @param process_callback function applied to each set's profile list
-- @param what label used in debug output
-- @return rule.watch_interval
local function check_anns(worker, cfg, ev_base, rule, process_callback, what)
  local function request_profiles(set)
    local function members_cb(err, data)
      if err then
        rspamd_logger.errx(cfg, 'cannot get ANNs list from redis: %s',
            err)
      elseif type(data) == 'table' then
        lua_util.debugm(N, cfg, '%s: process element %s:%s (profiles=%s)',
            what, rule.prefix, set.name, #data)
        -- fun.totable materialises the iterator (Lua 5.4 compatibility)
        local profiles = fun.totable(fun.map(load_ann_profile, data))
        process_callback(worker, ev_base, rule, set, profiles)
      else
        lua_util.debugm(N, cfg, '%s: no profiles for %s:%s', what, rule.prefix, set.name)
      end
      -- Not reached when process_callback raises.
      set.can_store_vectors = true
    end

    lua_redis.redis_make_request_taskless(ev_base,
        cfg,
        rule.redis,
        nil,
        false, -- is write
        members_cb,
        'ZREVRANGE',
        { set.prefix, '0', tostring(settings.max_profiles) }
    )
  end

  for _, set in pairs(rule.settings) do
    if type(set) == 'table' then
      request_profiles(set)
    end
  end

  return rule.watch_interval
end
--- Run the maybe_invalidate Redis script for every settings set of `rule`
-- and log each profile the script reports as invalidated.
-- @param rule neural rule configuration
-- @param cfg rspamd config used for logging
-- @param ev_base event base for the script execution
local function cleanup_anns(rule, cfg, ev_base)
  for _, set in pairs(rule.settings) do
    local function invalidate_cb(err, data)
      if err then
        rspamd_logger.errx(cfg, 'cannot exec invalidate script in redis: %s',
            err)
      elseif type(data) == 'table' then
        for _, expired in ipairs(data) do
          local profile = load_ann_profile(expired)
          if profile then
            rspamd_logger.infox(cfg, 'invalidated ANN for %s; redis key: %s; version=%s',
                rule.prefix .. ':' .. set.name,
                profile.redis_key,
                profile.version)
          else
            -- load_ann_profile returns nil (after logging the reason) for
            -- malformed entries; do not index it, and keep reporting the rest.
            rspamd_logger.infox(cfg, 'invalidated unparseable ANN profile for %s',
                rule.prefix .. ':' .. set.name)
          end
        end
      end
    end
    if type(set) == 'table' then
      lua_redis.exec_redis_script(neural_common.redis_script_id.maybe_invalidate,
          { ev_base = ev_base, is_write = true },
          invalidate_cb,
          { set.prefix, tostring(settings.max_profiles) })
    end
  end
end
--- NEURAL_LEARN callback: pass the task's verdict and score to
-- ann_push_task_result for every rule whose settings match the task.
-- Nothing is pushed for tasks flagged 'skip', for rspamc/controller scans
-- (unless allow_local is set or an ANN-Train header is present), for a
-- 'passthrough' verdict, or when the score is NaN.
-- @param task rspamd task
local function ann_push_vector(task)
  if task:has_flag('skip') then
    lua_util.debugm(N, task, 'do not push data for skipped task')
    return
  end

  -- An ANN-Train header allows manual training regardless of allow_local.
  local train_hdr = get_ann_train_header(task)
  local local_scan_blocked = not settings.allow_local and not train_hdr
  if local_scan_blocked and lua_util.is_rspamc_or_controller(task) then
    lua_util.debugm(N, task, 'do not push data for manual scan')
    return
  end

  local verdict, score = lua_verdict.get_specific_verdict(N, task)
  if verdict == 'passthrough' then
    lua_util.debugm(N, task, 'ignore task as its verdict is %s(%s)',
        verdict, score)
    return
  end
  -- NaN is the only value that is not equal to itself.
  if score ~= score then
    lua_util.debugm(N, task, 'ignore task as its score is nan (%s verdict)',
        verdict)
    return
  end

  for _, rule in pairs(settings.rules) do
    local rule_set = neural_common.get_rule_settings(task, rule)
    if not rule_set then
      lua_util.debugm(N, task, 'settings not found in rule %s', rule.prefix)
    else
      ann_push_task_result(rule, task, verdict, score, rule_set)
    end
  end
end
-- Initialization: disable the module unless there is a module config table
-- and Redis parameters to work with.
if type(neural_common.module_config) ~= 'table' or not neural_common.redis_params then
  rspamd_logger.infox(rspamd_config, 'Module is unconfigured')
  lua_util.disable_module(N, "redis")
  return
end
-- Rules come from the `rules` section; in the legacy layout the whole
-- module config is treated as a single rule named 'default'.
local rules = neural_common.module_config['rules'] or {
  default = neural_common.module_config,
}

-- Parent callback symbol; per-rule spam/ham virtual symbols attach to it.
local id = rspamd_config:register_symbol({
  name = 'NEURAL_CHECK',
  type = 'postfilter,callback',
  flags = 'nostat',
  priority = lua_util.symbols_priorities.medium,
  callback = ann_scores_filter,
})

-- Reset; rules are added back only once they have been validated.
neural_common.settings.rules = {}

-- Turn the blacklist list into a hash for direct lookups.
if settings.blacklisted_symbols and settings.blacklisted_symbols[1] then
  settings.blacklisted_symbols = lua_util.list_to_hash(settings.blacklisted_symbols)
end
-- Validate every configured rule, store it in settings.rules and register
-- its spam/ham symbols.
for k, r in pairs(rules) do
  local rule_elt = lua_util.override_defaults(neural_common.default_options, r)
  rule_elt['redis'] = neural_common.redis_params
  rule_elt['anns'] = {} -- Store ANNs here
  if not rule_elt.prefix then
    rule_elt.prefix = k
  end
  if not rule_elt.name then
    rule_elt.name = k
  end
  -- Accept the legacy `max_train` spelling.
  if rule_elt.train.max_train and not rule_elt.train.max_trains then
    rule_elt.train.max_trains = rule_elt.train.max_train
  end
  if not rule_elt.profile then
    rule_elt.profile = {}
  end
  if rule_elt.max_inputs and not has_blas then
    rspamd_logger.errx(rspamd_config,
        'cannot set max inputs to %s for rule %s as BLAS is not compiled in',
        rule_elt.max_inputs, rule_elt.name)
    rule_elt.max_inputs = nil
  end
  -- Phase 4: basic provider config validation + init
  if rule_elt.providers and #rule_elt.providers > 0 then
    for i, pcfg in ipairs(rule_elt.providers) do
      if not (pcfg.type or pcfg.name) then
        -- Really ignore it: there is no provider name to look up or init.
        rspamd_logger.errx(rspamd_config, 'provider at index %s in rule %s has no type/name; will be ignored', i, k)
      else
        if (pcfg.type == 'llm' or pcfg.name == 'llm') and not (pcfg.model or (rspamd_config:get_all_opt('gpt') or {}).model) then
          rspamd_logger.errx(rspamd_config,
              'llm provider in rule %s requires model; please set providers[%s].model or gpt.model', k, i)
        end
        -- Call provider init at config time (for map registration etc.)
        local prov = neural_common.get_provider(pcfg.type or pcfg.name)
        if prov and prov.init then
          prov.init(pcfg)
        end
      end
    end
  end
  -- External model configuration
  if rule_elt.external_model then
    local providers_digest = neural_common.providers_config_digest(rule_elt.providers, rule_elt)
    rule_elt.external_model = neural_external.create_external_config(rule_elt, providers_digest)
    if rule_elt.external_model then
      rspamd_logger.infox(rspamd_config, "configured external model for rule %s: url=%s, merge_alpha=%s",
          k, rule_elt.external_model.url, rule_elt.external_model.merge_alpha)
    end
  end
  rspamd_logger.infox(rspamd_config, "register ann rule %s", k)
  settings.rules[k] = rule_elt
  -- Register external model map if configured
  if rule_elt.external_model then
    register_external_model_map(rule_elt)
  end
  rspamd_config:set_metric_symbol({
    name = rule_elt.symbol_spam,
    score = 0.0,
    description = 'Neural network SPAM',
    group = 'neural'
  })
  rspamd_config:register_symbol({
    name = rule_elt.symbol_spam,
    type = 'virtual',
    flags = 'nostat',
    parent = id
  })
  rspamd_config:set_metric_symbol({
    name = rule_elt.symbol_ham,
    score = -0.0,
    description = 'Neural network HAM',
    group = 'neural'
  })
  rspamd_config:register_symbol({
    name = rule_elt.symbol_ham,
    type = 'virtual',
    flags = 'nostat',
    parent = id
  })
end
-- Idempotent symbol whose callback (ann_push_vector) collects training data.
rspamd_config:register_symbol({
  name = 'NEURAL_LEARN',
  type = 'idempotent,callback',
  flags = 'nostat,explicit_disable,ignore_passthrough',
  callback = ann_push_vector,
})

-- Settings are processed after the configuration has been loaded.
rspamd_config:add_post_init(neural_common.process_rules_settings)

-- For each rule: load its Redis scripts, then once a worker starts,
-- schedule the periodic jobs that suit the worker's role.
for _, rule in pairs(settings.rules) do
  neural_common.load_scripts(rule.redis)

  rspamd_config:add_on_load(function(cfg, ev_base, worker)
    -- Scanner job: check Redis for ANNs to load.
    local function load_tick()
      return check_anns(worker, cfg, ev_base, rule, process_existing_ann,
          'try_load_ann')
    end

    -- Primary controller job: clean up old ANNs, check for external model
    -- updates, then check whether any ANN has enough data to be trained.
    local function train_tick()
      cleanup_anns(rule, cfg, ev_base)
      check_external_model(worker, cfg, ev_base, rule)
      return check_anns(worker, cfg, ev_base, rule, maybe_train_existing_ann,
          'try_train_ann')
    end

    if worker:is_scanner() then
      rspamd_config:add_periodic(ev_base, 0.0, load_tick)
    end
    if worker:is_primary_controller() then
      rspamd_config:add_periodic(ev_base, 0.0, train_tick)
    end
  end)
end
-- Expose the autolearn hooks through rspamd_plugins.neural so user code can
-- customise learning.
if rspamd_plugins then
  local neural_api = rspamd_plugins['neural'] or {}
  rspamd_plugins['neural'] = neural_api

  neural_api.autolearn = {
    -- register_guard(cb): cb(task, learn_type, ctx) -> bool, reason;
    -- a guard can block learning
    register_guard = neural_learn.register_guard,
    -- Remove a previously registered guard
    unregister_guard = neural_learn.unregister_guard,
    -- Configure global autolearn defaults
    configure = neural_learn.configure,
    -- Whether a task qualifies for autolearn -> can_learn (bool), reason (string)
    can_autolearn = neural_learn.can_autolearn,
    -- Learn type for a task based on conditions -> 'spam', 'ham' or nil
    get_learn_type = neural_learn.get_learn_type,
    -- Store the autolearn class in the mempool (triggers learning in the
    -- idempotent callback)
    set_autolearn_class = neural_learn.set_autolearn_class,
    -- Read the autolearn class back from the mempool
    get_autolearn_class = neural_learn.get_autolearn_class,
  }
end