diff --git a/mm/memcontrol.c b/mm/memcontrol.c index 2d4d65f25fecd4..1c05c3832099cd 100644 --- a/mm/memcontrol.c +++ b/mm/memcontrol.c @@ -3062,23 +3062,81 @@ bool __memcg_slab_post_alloc_hook(struct kmem_cache *s, struct list_lru *lru, return false; } - if (obj_cgroup_charge(objcg, flags, size * obj_full_size(s))) - return false; + { + const size_t obj_size = obj_full_size(s); + /* + * Cap run length to prevent integer overflow when computing + * batch_bytes, which is passed as int to mod_objcg_state(). + */ + const size_t max_run = min_t(size_t, + (INT_MAX - PAGE_SIZE) / obj_size, + size); - for (i = 0; i < size; i++) { - slab = virt_to_slab(p[i]); + for (i = 0; i < size; ) { + struct pglist_data *pgdat; + size_t run_end; + int batch_bytes; - if (!slab_obj_exts(slab) && - alloc_slab_obj_exts(slab, s, flags, false)) { - obj_cgroup_uncharge(objcg, obj_full_size(s)); - continue; - } + slab = virt_to_slab(p[i]); - off = obj_to_index(s, slab, p[i]); - obj_cgroup_get(objcg); - slab_obj_exts(slab)[off].objcg = objcg; - mod_objcg_state(objcg, slab_pgdat(slab), - cache_vmstat_idx(s), obj_full_size(s)); + if (!slab_obj_exts(slab) && + alloc_slab_obj_exts(slab, s, flags, false)) { + i++; + continue; + } + + pgdat = slab_pgdat(slab); + run_end = i + 1; + + /* + * Scan ahead for objects on the same pgdat node so we + * can batch the memcg charge into a single call. + */ + while (run_end < size && (run_end - i) < max_run) { + struct slab *slab_j = virt_to_slab(p[run_end]); + + if (slab_pgdat(slab_j) != pgdat) + break; + + if (!slab_obj_exts(slab_j) && + alloc_slab_obj_exts(slab_j, s, flags, false)) + break; + + run_end++; + } + + /* + * If we fail and size is 1, memcg_alloc_abort_single() + * will just free the object, which is ok as we have not + * assigned objcg to its obj_ext yet. + * + * For larger sizes, kmem_cache_free_bulk() will uncharge + * any objects that were already charged and obj_ext + * assigned. + */ + batch_bytes = (int)(obj_size * (run_end - i)); + if (obj_cgroup_charge(objcg, flags, batch_bytes)) + return false; + + /* + * Assign obj_ext for each object in this run. + * Reuse 'slab' for the first object (already computed + * above) to avoid a redundant virt_to_slab() call. + */ + off = obj_to_index(s, slab, p[i]); + obj_cgroup_get(objcg); + slab_obj_exts(slab)[off].objcg = objcg; + + while (++i < run_end) { + slab = virt_to_slab(p[i]); + off = obj_to_index(s, slab, p[i]); + obj_cgroup_get(objcg); + slab_obj_exts(slab)[off].objcg = objcg; + } + + mod_objcg_state(objcg, pgdat, cache_vmstat_idx(s), + batch_bytes); + } } return true;