Skip to content

Commit

Permalink
use doctrine.middleware tag for registering middlewares (#300)
Browse files Browse the repository at this point in the history
  • Loading branch information
dmaicher authored May 28, 2024
1 parent e193f5e commit 1f81a28
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 106 deletions.
2 changes: 1 addition & 1 deletion composer.json
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
"require": {
"php": "^7.4 || ^8.0",
"doctrine/dbal": "^3.3 || ^4.0",
"doctrine/doctrine-bundle": "^2.2.2",
"doctrine/doctrine-bundle": "^2.11.0",
"psr/cache": "^1.0 || ^2.0 || ^3.0",
"symfony/cache": "^5.4 || ^6.3 || ^7.0",
"symfony/framework-bundle": "^5.4 || ^6.3 || ^7.0"
Expand Down
8 changes: 6 additions & 2 deletions src/DAMA/DoctrineTestBundle/DAMADoctrineTestBundle.php
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

namespace DAMA\DoctrineTestBundle;

use DAMA\DoctrineTestBundle\DependencyInjection\DoctrineTestCompilerPass;
use DAMA\DoctrineTestBundle\DependencyInjection\AddMiddlewaresCompilerPass;
use DAMA\DoctrineTestBundle\DependencyInjection\ModifyDoctrineConfigCompilerPass;
use Symfony\Component\DependencyInjection\Compiler\PassConfig;
use Symfony\Component\DependencyInjection\ContainerBuilder;
use Symfony\Component\HttpKernel\Bundle\Bundle;
Expand All @@ -13,6 +14,9 @@ public function build(ContainerBuilder $container): void
{
parent::build($container);
// lower priority than CacheCompatibilityPass from DoctrineBundle
$container->addCompilerPass(new DoctrineTestCompilerPass(), PassConfig::TYPE_BEFORE_OPTIMIZATION, -1);
$container->addCompilerPass(new ModifyDoctrineConfigCompilerPass(), PassConfig::TYPE_BEFORE_OPTIMIZATION, -1);

// higher priority than MiddlewaresPass from DoctrineBundle
$container->addCompilerPass(new AddMiddlewaresCompilerPass(), PassConfig::TYPE_BEFORE_OPTIMIZATION, 1);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
<?php

declare(strict_types=1);

namespace DAMA\DoctrineTestBundle\DependencyInjection;

use DAMA\DoctrineTestBundle\Doctrine\DBAL\Middleware;
use Symfony\Component\DependencyInjection\Compiler\CompilerPassInterface;
use Symfony\Component\DependencyInjection\ContainerBuilder;

final class AddMiddlewaresCompilerPass implements CompilerPassInterface
{
public const TRANSACTIONAL_BEHAVIOR_ENABLED_CONNECTIONS = 'dama.doctrine_test.transactional_behavior_enabled_connections';

public function process(ContainerBuilder $container): void
{
/** @var array<string, mixed> $connections */
$connections = $container->getParameter('doctrine.connections');
$connectionNames = array_keys($connections);
$transactionalBehaviorEnabledConnections = $this->getTransactionEnabledConnectionNames($container, $connectionNames);
$container->getParameterBag()->set(self::TRANSACTIONAL_BEHAVIOR_ENABLED_CONNECTIONS, $transactionalBehaviorEnabledConnections);

foreach ($transactionalBehaviorEnabledConnections as $name) {
$middlewareDefinition = $container->register(sprintf('dama.doctrine.dbal.middleware.%s', $name), Middleware::class);
$middlewareDefinition->addTag('doctrine.middleware', ['connection' => $name, 'priority' => 100]);
}

$container->getParameterBag()->remove('dama.'.Configuration::ENABLE_STATIC_CONNECTION);
}

/**
* @param string[] $connectionNames
*
* @return string[]
*/
private function getTransactionEnabledConnectionNames(ContainerBuilder $container, array $connectionNames): array
{
/** @var bool|array<string, bool> $enableStaticConnectionsConfig */
$enableStaticConnectionsConfig = $container->getParameter('dama.'.Configuration::ENABLE_STATIC_CONNECTION);

if (is_array($enableStaticConnectionsConfig)) {
$this->validateConnectionNames(array_keys($enableStaticConnectionsConfig), $connectionNames);
}

$enabledConnections = [];

foreach ($connectionNames as $name) {
if ($enableStaticConnectionsConfig === true
|| isset($enableStaticConnectionsConfig[$name]) && $enableStaticConnectionsConfig[$name] === true
) {
$enabledConnections[] = $name;
}
}

return $enabledConnections;
}

/**
* @param string[] $configNames
* @param string[] $existingNames
*/
private function validateConnectionNames(array $configNames, array $existingNames): void
{
$unknown = array_diff($configNames, $existingNames);

if (count($unknown)) {
throw new \InvalidArgumentException(sprintf('Unknown doctrine dbal connection name(s): %s.', implode(', ', $unknown)));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,18 @@
namespace DAMA\DoctrineTestBundle\DependencyInjection;

use DAMA\DoctrineTestBundle\Doctrine\Cache\Psr6StaticArrayCache;
use DAMA\DoctrineTestBundle\Doctrine\DBAL\Middleware;
use Doctrine\Common\Cache\Cache;
use Doctrine\DBAL\Connection;
use Psr\Cache\CacheItemPoolInterface;
use Symfony\Component\DependencyInjection\ChildDefinition;
use Symfony\Component\DependencyInjection\Compiler\CompilerPassInterface;
use Symfony\Component\DependencyInjection\ContainerBuilder;
use Symfony\Component\DependencyInjection\Definition;
use Symfony\Component\DependencyInjection\Reference;

class DoctrineTestCompilerPass implements CompilerPassInterface
final class ModifyDoctrineConfigCompilerPass implements CompilerPassInterface
{
public function process(ContainerBuilder $container): void
{
$container->register('dama.doctrine.dbal.middleware', Middleware::class);
$cacheNames = [];

if ($container->getParameter('dama.'.Configuration::STATIC_META_CACHE)) {
Expand All @@ -31,7 +28,11 @@ public function process(ContainerBuilder $container): void
/** @var array<string, mixed> $connections */
$connections = $container->getParameter('doctrine.connections');
$connectionNames = array_keys($connections);
$transactionalBehaviorEnabledConnections = $this->getTransactionEnabledConnectionNames($container, $connectionNames);

/** @var string[] $transactionalBehaviorEnabledConnections */
$transactionalBehaviorEnabledConnections = $container->getParameter(
AddMiddlewaresCompilerPass::TRANSACTIONAL_BEHAVIOR_ENABLED_CONNECTIONS,
);
$connectionKeys = $this->getConnectionKeys($container, $connectionNames);

foreach ($connectionNames as $name) {
Expand All @@ -56,10 +57,10 @@ public function process(ContainerBuilder $container): void
}
}

$container->getParameterBag()->remove('dama.'.Configuration::ENABLE_STATIC_CONNECTION);
$container->getParameterBag()->remove('dama.'.Configuration::STATIC_META_CACHE);
$container->getParameterBag()->remove('dama.'.Configuration::STATIC_QUERY_CACHE);
$container->getParameterBag()->remove('dama.'.Configuration::CONNECTION_KEYS);
$container->getParameterBag()->remove(AddMiddlewaresCompilerPass::TRANSACTIONAL_BEHAVIOR_ENABLED_CONNECTIONS);
}

/**
Expand All @@ -79,24 +80,6 @@ private function modifyConnectionService(ContainerBuilder $container, $connectio
0,
$this->getModifiedConnectionOptions($connectionOptions, $connectionKey, $name),
);

$connectionConfig = $container->getDefinition(sprintf('doctrine.dbal.%s_connection.configuration', $name));
$methodCalls = $connectionConfig->getMethodCalls();
$middlewareRef = new Reference('dama.doctrine.dbal.middleware');
$hasMiddlewaresMethodCall = false;
foreach ($methodCalls as &$methodCall) {
if ($methodCall[0] === 'setMiddlewares') {
$hasMiddlewaresMethodCall = true;
// our middleware needs to be the first one here so we wrap the "native" driver
$methodCall[1][0] = array_merge([$middlewareRef], $methodCall[1][0]);
}
}

if (!$hasMiddlewaresMethodCall) {
$methodCalls[] = ['setMiddlewares', [[$middlewareRef]]];
}

$connectionConfig->setMethodCalls($methodCalls);
}

/**
Expand Down Expand Up @@ -166,33 +149,6 @@ private function registerStaticCache(
$container->setDefinition($cacheServiceId, $cache);
}

/**
* @param string[] $connectionNames
*
* @return string[]
*/
private function getTransactionEnabledConnectionNames(ContainerBuilder $container, array $connectionNames): array
{
/** @var bool|array<string, bool> $enableStaticConnectionsConfig */
$enableStaticConnectionsConfig = $container->getParameter('dama.'.Configuration::ENABLE_STATIC_CONNECTION);

if (is_array($enableStaticConnectionsConfig)) {
$this->validateConnectionNames(array_keys($enableStaticConnectionsConfig), $connectionNames);
}

$enabledConnections = [];

foreach ($connectionNames as $name) {
if ($enableStaticConnectionsConfig === true
|| isset($enableStaticConnectionsConfig[$name]) && $enableStaticConnectionsConfig[$name] === true
) {
$enabledConnections[] = $name;
}
}

return $enabledConnections;
}

/**
* @param string[] $configNames
* @param string[] $existingNames
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
<?php

declare(strict_types=1);

namespace Tests\DAMA\DoctrineTestBundle\DependencyInjection;

use DAMA\DoctrineTestBundle\DependencyInjection\AddMiddlewaresCompilerPass;
use DAMA\DoctrineTestBundle\DependencyInjection\DAMADoctrineTestExtension;
use DAMA\DoctrineTestBundle\DependencyInjection\DoctrineTestCompilerPass;
use DAMA\DoctrineTestBundle\DependencyInjection\ModifyDoctrineConfigCompilerPass;
use DAMA\DoctrineTestBundle\Doctrine\Cache\Psr6StaticArrayCache;
use Doctrine\Bundle\DoctrineBundle\ConnectionFactory;
use Doctrine\DBAL\Configuration;
Expand All @@ -13,9 +16,8 @@
use Symfony\Component\DependencyInjection\ChildDefinition;
use Symfony\Component\DependencyInjection\ContainerBuilder;
use Symfony\Component\DependencyInjection\Definition;
use Symfony\Component\DependencyInjection\Reference;

class DoctrineTestCompilerPassTest extends TestCase
class CompilerPassesTest extends TestCase
{
private const CACHE_SERVICE_IDS = [
'doctrine.orm.a_metadata_cache',
Expand Down Expand Up @@ -55,11 +57,6 @@ public function testProcess(array $config, callable $assertCallback, ?callable $
;
}

$containerBuilder->setDefinition(
'doctrine.dbal.a_connection.configuration',
(new Definition(Configuration::class))
->setMethodCalls([['setMiddlewares', [[new Reference('foo')]]]])
);
$containerBuilder->setDefinition('doctrine.dbal.b_connection.configuration', new Definition(Configuration::class));
$containerBuilder->setDefinition('doctrine.dbal.c_connection.configuration', new Definition(Configuration::class));

Expand All @@ -69,7 +66,8 @@ public function testProcess(array $config, callable $assertCallback, ?callable $
$expectationCallback($this, $containerBuilder);
}

(new DoctrineTestCompilerPass())->process($containerBuilder);
(new AddMiddlewaresCompilerPass())->process($containerBuilder);
(new ModifyDoctrineConfigCompilerPass())->process($containerBuilder);

foreach (array_keys($containerBuilder->getParameterBag()->all()) as $parameterName) {
$this->assertStringStartsNotWith('dama.', $parameterName);
Expand Down Expand Up @@ -97,35 +95,6 @@ function (ContainerBuilder $containerBuilder): void {
self::assertSame([
'dama.connection_key' => 'a',
], $containerBuilder->getDefinition('doctrine.dbal.a_connection')->getArgument(0));

self::assertEquals(
[
[
'setMiddlewares',
[
[
new Reference('dama.doctrine.dbal.middleware'),
new Reference('foo'),
],
],
],
],
$containerBuilder->getDefinition('doctrine.dbal.a_connection.configuration')->getMethodCalls()
);

self::assertEquals(
[
[
'setMiddlewares',
[
[
new Reference('dama.doctrine.dbal.middleware'),
],
],
],
],
$containerBuilder->getDefinition('doctrine.dbal.b_connection.configuration')->getMethodCalls()
);
},
];

Expand All @@ -137,20 +106,6 @@ function (ContainerBuilder $containerBuilder): void {
],
function (ContainerBuilder $containerBuilder): void {
self::assertFalse($containerBuilder->hasDefinition('doctrine.orm.a_metadata_cache'));

self::assertEquals(
[
[
'setMiddlewares',
[
[
new Reference('foo'),
],
],
],
],
$containerBuilder->getDefinition('doctrine.dbal.a_connection.configuration')->getMethodCalls()
);
},
];

Expand All @@ -167,8 +122,14 @@ function (ContainerBuilder $containerBuilder): void {
self::assertSame([
'dama.connection_key' => 'a',
], $containerBuilder->getDefinition('doctrine.dbal.a_connection')->getArgument(0));
self::assertTrue($containerBuilder->hasDefinition('dama.doctrine.dbal.middleware.a'));
self::assertSame([
'connection' => 'a',
'priority' => 100,
], $containerBuilder->getDefinition('dama.doctrine.dbal.middleware.a')->getTag('doctrine.middleware')[0]);

self::assertSame([], $containerBuilder->getDefinition('doctrine.dbal.b_connection')->getArgument(0));
self::assertFalse($containerBuilder->hasDefinition('dama.doctrine.dbal.middleware.b'));

self::assertSame(
[
Expand All @@ -187,6 +148,7 @@ function (ContainerBuilder $containerBuilder): void {
],
$containerBuilder->getDefinition('doctrine.dbal.c_connection')->getArgument(0)
);
self::assertTrue($containerBuilder->hasDefinition('dama.doctrine.dbal.middleware.c'));
},
];

Expand Down

0 comments on commit 1f81a28

Please sign in to comment.