diff --git a/mm/slab.h b/mm/slab.h index eab7d1b88570f..ec4908fe6efb6 100644 --- a/mm/slab.h +++ b/mm/slab.h @@ -614,8 +614,10 @@ static inline void cache_random_seq_destroy(struct kmem_cache *cachep) { } static inline bool slab_want_init_on_alloc(gfp_t flags, struct kmem_cache *c) { if (static_branch_unlikely(&init_on_alloc)) { +#ifndef CONFIG_SLUB if (c->ctor) return false; +#endif if (c->flags & (SLAB_TYPESAFE_BY_RCU | SLAB_POISON)) return flags & __GFP_ZERO; return true; diff --git a/mm/slub.c b/mm/slub.c index 67344ff6a893d..04ddfb5c0080a 100644 --- a/mm/slub.c +++ b/mm/slub.c @@ -1639,9 +1639,10 @@ static inline bool slab_free_freelist_hook(struct kmem_cache *s, * need to show a valid freepointer to check_object(). * * Note that doing this for all caches (not just ctor - * ones, which have s->offset != NULL)) causes a GPF, - * due to KASAN poisoning and the way set_freepointer() - * eventually dereferences the freepointer. + * ones, which have s->offset >= object_size)) causes a + * GPF, due to KASAN poisoning and the way + * set_freepointer() eventually dereferences the + * freepointer. */ set_freepointer(s, object, NULL); } @@ -2956,8 +2957,14 @@ static __always_inline void *slab_alloc_node(struct kmem_cache *s, if (s->ctor) s->ctor(object); kasan_poison_object_data(s, object); - } else if (unlikely(slab_want_init_on_alloc(gfpflags, s)) && object) + } else if (unlikely(slab_want_init_on_alloc(gfpflags, s)) && object) { memset(object, 0, s->object_size); + if (s->ctor) { + kasan_unpoison_object_data(s, object); + s->ctor(object); + kasan_poison_object_data(s, object); + } + } if (object) { check_canary(s, object, s->random_inactive); @@ -3415,8 +3422,14 @@ int kmem_cache_alloc_bulk(struct kmem_cache *s, gfp_t flags, size_t size, } else if (unlikely(slab_want_init_on_alloc(flags, s))) { int j; - for (j = 0; j < i; j++) + for (j = 0; j < i; j++) { memset(p[j], 0, s->object_size); + if (s->ctor) { + kasan_unpoison_object_data(s, p[j]); + s->ctor(p[j]); + kasan_poison_object_data(s, p[j]); + } + } } for (k = 0; k < i; k++) {