From e2988017c23270acd95e25ec3289983ecc3895f7 Mon Sep 17 00:00:00 2001 From: Amanda Hernando <110099762+amanda-her@users.noreply.github.com> Date: Wed, 11 Oct 2023 01:36:01 +0200 Subject: [PATCH 1/7] feat(auth): add data platform instance field resolver provider (#8828) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Sergio Gómez Villamor Co-authored-by: Adrián Pertíñez --- .../authorization/ResolvedResourceSpec.java | 17 ++ .../authorization/ResourceFieldType.java | 6 +- .../DefaultResourceSpecResolver.java | 9 +- ...PlatformInstanceFieldResolverProvider.java | 70 +++++++ ...formInstanceFieldResolverProviderTest.java | 188 ++++++++++++++++++ 5 files changed, 286 insertions(+), 4 deletions(-) create mode 100644 metadata-service/auth-impl/src/main/java/com/datahub/authorization/fieldresolverprovider/DataPlatformInstanceFieldResolverProvider.java create mode 100644 metadata-service/auth-impl/src/test/java/com/datahub/authorization/fieldresolverprovider/DataPlatformInstanceFieldResolverProviderTest.java diff --git a/metadata-auth/auth-api/src/main/java/com/datahub/authorization/ResolvedResourceSpec.java b/metadata-auth/auth-api/src/main/java/com/datahub/authorization/ResolvedResourceSpec.java index 53dd0be44f963d..8e429a8ca1b944 100644 --- a/metadata-auth/auth-api/src/main/java/com/datahub/authorization/ResolvedResourceSpec.java +++ b/metadata-auth/auth-api/src/main/java/com/datahub/authorization/ResolvedResourceSpec.java @@ -3,6 +3,7 @@ import java.util.Collections; import java.util.Map; import java.util.Set; +import javax.annotation.Nullable; import lombok.Getter; import lombok.RequiredArgsConstructor; import lombok.ToString; @@ -35,4 +36,20 @@ public Set getOwners() { } return fieldResolvers.get(ResourceFieldType.OWNER).getFieldValuesFuture().join().getValues(); } + + /** + * Fetch the platform instance for a Resolved Resource Spec + * @return a Platform Instance or null if one does not exist. 
+ */ + @Nullable + public String getDataPlatformInstance() { + if (!fieldResolvers.containsKey(ResourceFieldType.DATA_PLATFORM_INSTANCE)) { + return null; + } + Set<String> dataPlatformInstance = fieldResolvers.get(ResourceFieldType.DATA_PLATFORM_INSTANCE).getFieldValuesFuture().join().getValues(); + if (dataPlatformInstance.size() > 0) { + return dataPlatformInstance.stream().findFirst().get(); + } + return null; + } } diff --git a/metadata-auth/auth-api/src/main/java/com/datahub/authorization/ResourceFieldType.java b/metadata-auth/auth-api/src/main/java/com/datahub/authorization/ResourceFieldType.java index ee54d2bfbba1da..478522dc7c331c 100644 --- a/metadata-auth/auth-api/src/main/java/com/datahub/authorization/ResourceFieldType.java +++ b/metadata-auth/auth-api/src/main/java/com/datahub/authorization/ResourceFieldType.java @@ -19,5 +19,9 @@ public enum ResourceFieldType { /** * Domains of resource */ - DOMAIN + DOMAIN, + /** + * Data platform instance of resource + */ + DATA_PLATFORM_INSTANCE } diff --git a/metadata-service/auth-impl/src/main/java/com/datahub/authorization/DefaultResourceSpecResolver.java b/metadata-service/auth-impl/src/main/java/com/datahub/authorization/DefaultResourceSpecResolver.java index cd4e0b09678296..64c43dc8aa591a 100644 --- a/metadata-service/auth-impl/src/main/java/com/datahub/authorization/DefaultResourceSpecResolver.java +++ b/metadata-service/auth-impl/src/main/java/com/datahub/authorization/DefaultResourceSpecResolver.java @@ -1,13 +1,15 @@ package com.datahub.authorization; -import com.datahub.authorization.fieldresolverprovider.EntityTypeFieldResolverProvider; -import com.datahub.authorization.fieldresolverprovider.OwnerFieldResolverProvider; import com.datahub.authentication.Authentication; +import com.datahub.authorization.fieldresolverprovider.DataPlatformInstanceFieldResolverProvider; import com.datahub.authorization.fieldresolverprovider.DomainFieldResolverProvider; +import com.datahub.authorization.fieldresolverprovider.EntityTypeFieldResolverProvider; import com.datahub.authorization.fieldresolverprovider.EntityUrnFieldResolverProvider; +import com.datahub.authorization.fieldresolverprovider.OwnerFieldResolverProvider; import com.datahub.authorization.fieldresolverprovider.ResourceFieldResolverProvider; import com.google.common.collect.ImmutableList; import com.linkedin.entity.client.EntityClient; + import java.util.List; import java.util.Map; import java.util.stream.Collectors; @@ -20,7 +22,8 @@ public DefaultResourceSpecResolver(Authentication systemAuthentication, EntityCl _resourceFieldResolverProviders = ImmutableList.of(new EntityTypeFieldResolverProvider(), new EntityUrnFieldResolverProvider(), new DomainFieldResolverProvider(entityClient, systemAuthentication), - new OwnerFieldResolverProvider(entityClient, systemAuthentication)); + new OwnerFieldResolverProvider(entityClient, systemAuthentication), + new DataPlatformInstanceFieldResolverProvider(entityClient, systemAuthentication)); } @Override diff --git a/metadata-service/auth-impl/src/main/java/com/datahub/authorization/fieldresolverprovider/DataPlatformInstanceFieldResolverProvider.java b/metadata-service/auth-impl/src/main/java/com/datahub/authorization/fieldresolverprovider/DataPlatformInstanceFieldResolverProvider.java new file mode 100644 index 00000000000000..cd838625c2ca1f --- /dev/null +++ b/metadata-service/auth-impl/src/main/java/com/datahub/authorization/fieldresolverprovider/DataPlatformInstanceFieldResolverProvider.java @@ -0,0 +1,70 @@ +package
com.datahub.authorization.fieldresolverprovider; + +import com.datahub.authentication.Authentication; +import com.datahub.authorization.FieldResolver; +import com.datahub.authorization.ResourceFieldType; +import com.datahub.authorization.ResourceSpec; +import com.linkedin.common.DataPlatformInstance; +import com.linkedin.common.urn.Urn; +import com.linkedin.common.urn.UrnUtils; +import com.linkedin.entity.EntityResponse; +import com.linkedin.entity.EnvelopedAspect; +import com.linkedin.entity.client.EntityClient; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; + +import java.util.Collections; +import java.util.Objects; + +import static com.linkedin.metadata.Constants.*; + +/** + * Provides field resolver for domain given resourceSpec + */ +@Slf4j +@RequiredArgsConstructor +public class DataPlatformInstanceFieldResolverProvider implements ResourceFieldResolverProvider { + + private final EntityClient _entityClient; + private final Authentication _systemAuthentication; + + @Override + public ResourceFieldType getFieldType() { + return ResourceFieldType.DATA_PLATFORM_INSTANCE; + } + + @Override + public FieldResolver getFieldResolver(ResourceSpec resourceSpec) { + return FieldResolver.getResolverFromFunction(resourceSpec, this::getDataPlatformInstance); + } + + private FieldResolver.FieldValue getDataPlatformInstance(ResourceSpec resourceSpec) { + Urn entityUrn = UrnUtils.getUrn(resourceSpec.getResource()); + // In the case that the entity is a platform instance, the associated platform instance entity is the instance itself + if (entityUrn.getEntityType().equals(DATA_PLATFORM_INSTANCE_ENTITY_NAME)) { + return FieldResolver.FieldValue.builder() + .values(Collections.singleton(entityUrn.toString())) + .build(); + } + + EnvelopedAspect dataPlatformInstanceAspect; + try { + EntityResponse response = _entityClient.getV2(entityUrn.getEntityType(), entityUrn, + Collections.singleton(DATA_PLATFORM_INSTANCE_ASPECT_NAME), _systemAuthentication); + if (response == null || !response.getAspects().containsKey(DATA_PLATFORM_INSTANCE_ASPECT_NAME)) { + return FieldResolver.emptyFieldValue(); + } + dataPlatformInstanceAspect = response.getAspects().get(DATA_PLATFORM_INSTANCE_ASPECT_NAME); + } catch (Exception e) { + log.error("Error while retrieving platform instance aspect for urn {}", entityUrn, e); + return FieldResolver.emptyFieldValue(); + } + DataPlatformInstance dataPlatformInstance = new DataPlatformInstance(dataPlatformInstanceAspect.getValue().data()); + if (dataPlatformInstance.getInstance() == null) { + return FieldResolver.emptyFieldValue(); + } + return FieldResolver.FieldValue.builder() + .values(Collections.singleton(Objects.requireNonNull(dataPlatformInstance.getInstance()).toString())) + .build(); + } +} \ No newline at end of file diff --git a/metadata-service/auth-impl/src/test/java/com/datahub/authorization/fieldresolverprovider/DataPlatformInstanceFieldResolverProviderTest.java b/metadata-service/auth-impl/src/test/java/com/datahub/authorization/fieldresolverprovider/DataPlatformInstanceFieldResolverProviderTest.java new file mode 100644 index 00000000000000..e525c602c26206 --- /dev/null +++ b/metadata-service/auth-impl/src/test/java/com/datahub/authorization/fieldresolverprovider/DataPlatformInstanceFieldResolverProviderTest.java @@ -0,0 +1,188 @@ +package com.datahub.authorization.fieldresolverprovider; + +import com.datahub.authentication.Authentication; +import com.datahub.authorization.ResourceFieldType; +import com.datahub.authorization.ResourceSpec; 
+import com.linkedin.common.DataPlatformInstance; +import com.linkedin.common.urn.Urn; +import com.linkedin.entity.Aspect; +import com.linkedin.entity.EntityResponse; +import com.linkedin.entity.EnvelopedAspect; +import com.linkedin.entity.EnvelopedAspectMap; +import com.linkedin.entity.client.EntityClient; +import com.linkedin.r2.RemoteInvocationException; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +import java.net.URISyntaxException; +import java.util.Collections; +import java.util.Set; + +import static com.linkedin.metadata.Constants.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; + +public class DataPlatformInstanceFieldResolverProviderTest { + + private static final String DATA_PLATFORM_INSTANCE_URN = + "urn:li:dataPlatformInstance:(urn:li:dataPlatform:s3,test-platform-instance)"; + private static final String RESOURCE_URN = + "urn:li:dataset:(urn:li:dataPlatform:s3,test-platform-instance.testDataset,PROD)"; + private static final ResourceSpec RESOURCE_SPEC = new ResourceSpec(DATASET_ENTITY_NAME, RESOURCE_URN); + + @Mock + private EntityClient entityClientMock; + @Mock + private Authentication systemAuthenticationMock; + + private DataPlatformInstanceFieldResolverProvider dataPlatformInstanceFieldResolverProvider; + + @BeforeMethod + public void setup() { + MockitoAnnotations.initMocks(this); + dataPlatformInstanceFieldResolverProvider = + new DataPlatformInstanceFieldResolverProvider(entityClientMock, systemAuthenticationMock); + } + + @Test + public void shouldReturnDataPlatformInstanceType() { + assertEquals(ResourceFieldType.DATA_PLATFORM_INSTANCE, dataPlatformInstanceFieldResolverProvider.getFieldType()); + } + + @Test + public void shouldReturnFieldValueWithResourceSpecIfTypeIsDataPlatformInstance() { + var resourceSpec = new ResourceSpec(DATA_PLATFORM_INSTANCE_ENTITY_NAME, DATA_PLATFORM_INSTANCE_URN); + + var result = dataPlatformInstanceFieldResolverProvider.getFieldResolver(resourceSpec); + + assertEquals(Set.of(DATA_PLATFORM_INSTANCE_URN), result.getFieldValuesFuture().join().getValues()); + verifyZeroInteractions(entityClientMock); + } + + @Test + public void shouldReturnEmptyFieldValueWhenResponseIsNull() throws RemoteInvocationException, URISyntaxException { + when(entityClientMock.getV2( + eq(DATASET_ENTITY_NAME), + any(Urn.class), + eq(Collections.singleton(DATA_PLATFORM_INSTANCE_ASPECT_NAME)), + eq(systemAuthenticationMock) + )).thenReturn(null); + + var result = dataPlatformInstanceFieldResolverProvider.getFieldResolver(RESOURCE_SPEC); + + assertTrue(result.getFieldValuesFuture().join().getValues().isEmpty()); + verify(entityClientMock, times(1)).getV2( + eq(DATASET_ENTITY_NAME), + any(Urn.class), + eq(Collections.singleton(DATA_PLATFORM_INSTANCE_ASPECT_NAME)), + eq(systemAuthenticationMock) + ); + } + + @Test + public void shouldReturnEmptyFieldValueWhenResourceHasNoDataPlatformInstance() + throws RemoteInvocationException, URISyntaxException { + var entityResponseMock = mock(EntityResponse.class); + when(entityResponseMock.getAspects()).thenReturn(new EnvelopedAspectMap()); + when(entityClientMock.getV2( + eq(DATASET_ENTITY_NAME), + any(Urn.class), + eq(Collections.singleton(DATA_PLATFORM_INSTANCE_ASPECT_NAME)), + eq(systemAuthenticationMock) + )).thenReturn(entityResponseMock); + + 
var result = dataPlatformInstanceFieldResolverProvider.getFieldResolver(RESOURCE_SPEC); + + assertTrue(result.getFieldValuesFuture().join().getValues().isEmpty()); + verify(entityClientMock, times(1)).getV2( + eq(DATASET_ENTITY_NAME), + any(Urn.class), + eq(Collections.singleton(DATA_PLATFORM_INSTANCE_ASPECT_NAME)), + eq(systemAuthenticationMock) + ); + } + + @Test + public void shouldReturnEmptyFieldValueWhenThereIsAnException() throws RemoteInvocationException, URISyntaxException { + when(entityClientMock.getV2( + eq(DATASET_ENTITY_NAME), + any(Urn.class), + eq(Collections.singleton(DATA_PLATFORM_INSTANCE_ASPECT_NAME)), + eq(systemAuthenticationMock) + )).thenThrow(new RemoteInvocationException()); + + var result = dataPlatformInstanceFieldResolverProvider.getFieldResolver(RESOURCE_SPEC); + + assertTrue(result.getFieldValuesFuture().join().getValues().isEmpty()); + verify(entityClientMock, times(1)).getV2( + eq(DATASET_ENTITY_NAME), + any(Urn.class), + eq(Collections.singleton(DATA_PLATFORM_INSTANCE_ASPECT_NAME)), + eq(systemAuthenticationMock) + ); + } + + @Test + public void shouldReturnEmptyFieldValueWhenDataPlatformInstanceHasNoInstance() + throws RemoteInvocationException, URISyntaxException { + + var dataPlatform = new DataPlatformInstance() + .setPlatform(Urn.createFromString("urn:li:dataPlatform:s3")); + var entityResponseMock = mock(EntityResponse.class); + var envelopedAspectMap = new EnvelopedAspectMap(); + envelopedAspectMap.put(DATA_PLATFORM_INSTANCE_ASPECT_NAME, + new EnvelopedAspect().setValue(new Aspect(dataPlatform.data()))); + when(entityResponseMock.getAspects()).thenReturn(envelopedAspectMap); + when(entityClientMock.getV2( + eq(DATASET_ENTITY_NAME), + any(Urn.class), + eq(Collections.singleton(DATA_PLATFORM_INSTANCE_ASPECT_NAME)), + eq(systemAuthenticationMock) + )).thenReturn(entityResponseMock); + + var result = dataPlatformInstanceFieldResolverProvider.getFieldResolver(RESOURCE_SPEC); + + assertTrue(result.getFieldValuesFuture().join().getValues().isEmpty()); + verify(entityClientMock, times(1)).getV2( + eq(DATASET_ENTITY_NAME), + any(Urn.class), + eq(Collections.singleton(DATA_PLATFORM_INSTANCE_ASPECT_NAME)), + eq(systemAuthenticationMock) + ); + } + + @Test + public void shouldReturnFieldValueWithDataPlatformInstanceOfTheResource() + throws RemoteInvocationException, URISyntaxException { + + var dataPlatformInstance = new DataPlatformInstance() + .setPlatform(Urn.createFromString("urn:li:dataPlatform:s3")) + .setInstance(Urn.createFromString(DATA_PLATFORM_INSTANCE_URN)); + var entityResponseMock = mock(EntityResponse.class); + var envelopedAspectMap = new EnvelopedAspectMap(); + envelopedAspectMap.put(DATA_PLATFORM_INSTANCE_ASPECT_NAME, + new EnvelopedAspect().setValue(new Aspect(dataPlatformInstance.data()))); + when(entityResponseMock.getAspects()).thenReturn(envelopedAspectMap); + when(entityClientMock.getV2( + eq(DATASET_ENTITY_NAME), + any(Urn.class), + eq(Collections.singleton(DATA_PLATFORM_INSTANCE_ASPECT_NAME)), + eq(systemAuthenticationMock) + )).thenReturn(entityResponseMock); + + var result = dataPlatformInstanceFieldResolverProvider.getFieldResolver(RESOURCE_SPEC); + + assertEquals(Set.of(DATA_PLATFORM_INSTANCE_URN), result.getFieldValuesFuture().join().getValues()); + verify(entityClientMock, times(1)).getV2( + eq(DATASET_ENTITY_NAME), + any(Urn.class), + eq(Collections.singleton(DATA_PLATFORM_INSTANCE_ASPECT_NAME)), + eq(systemAuthenticationMock) + ); + } +} From a17db676e37d90ec47f16a43ab95e0d562952939 Mon Sep 17 00:00:00 2001 From: siladitya 
<68184387+siladitya2@users.noreply.github.com> Date: Wed, 11 Oct 2023 02:43:36 +0200 Subject: [PATCH 2/7] feat(graphql): Added datafetcher for DataPlatformInstance entity (#8935) Co-authored-by: si-chakraborty Co-authored-by: John Joyce --- .../datahub/graphql/GmsGraphQLEngine.java | 1 + .../DataPlatformInstanceType.java | 34 ++++++++++++++++++- .../src/main/resources/entity.graphql | 5 +++ 3 files changed, 39 insertions(+), 1 deletion(-) diff --git a/datahub-graphql-core/src/main/java/com/linkedin/datahub/graphql/GmsGraphQLEngine.java b/datahub-graphql-core/src/main/java/com/linkedin/datahub/graphql/GmsGraphQLEngine.java index 3ba0cc1f747e30..ebb5c7d62c7d3a 100644 --- a/datahub-graphql-core/src/main/java/com/linkedin/datahub/graphql/GmsGraphQLEngine.java +++ b/datahub-graphql-core/src/main/java/com/linkedin/datahub/graphql/GmsGraphQLEngine.java @@ -821,6 +821,7 @@ private void configureQueryResolvers(final RuntimeWiring.Builder builder) { .dataFetcher("glossaryNode", getResolver(glossaryNodeType)) .dataFetcher("domain", getResolver((domainType))) .dataFetcher("dataPlatform", getResolver(dataPlatformType)) + .dataFetcher("dataPlatformInstance", getResolver(dataPlatformInstanceType)) .dataFetcher("mlFeatureTable", getResolver(mlFeatureTableType)) .dataFetcher("mlFeature", getResolver(mlFeatureType)) .dataFetcher("mlPrimaryKey", getResolver(mlPrimaryKeyType)) diff --git a/datahub-graphql-core/src/main/java/com/linkedin/datahub/graphql/types/dataplatforminstance/DataPlatformInstanceType.java b/datahub-graphql-core/src/main/java/com/linkedin/datahub/graphql/types/dataplatforminstance/DataPlatformInstanceType.java index 2423fc31ea52e3..87614e13325283 100644 --- a/datahub-graphql-core/src/main/java/com/linkedin/datahub/graphql/types/dataplatforminstance/DataPlatformInstanceType.java +++ b/datahub-graphql-core/src/main/java/com/linkedin/datahub/graphql/types/dataplatforminstance/DataPlatformInstanceType.java @@ -4,16 +4,25 @@ import com.linkedin.common.urn.Urn; import com.linkedin.common.urn.UrnUtils; import com.linkedin.datahub.graphql.QueryContext; +import com.linkedin.datahub.graphql.generated.AutoCompleteResults; import com.linkedin.datahub.graphql.generated.DataPlatformInstance; import com.linkedin.datahub.graphql.generated.Entity; import com.linkedin.datahub.graphql.generated.EntityType; +import com.linkedin.datahub.graphql.generated.FacetFilterInput; +import com.linkedin.datahub.graphql.generated.SearchResults; import com.linkedin.datahub.graphql.types.dataplatforminstance.mappers.DataPlatformInstanceMapper; +import com.linkedin.datahub.graphql.types.mappers.AutoCompleteResultsMapper; +import com.linkedin.datahub.graphql.types.SearchableEntityType; import com.linkedin.entity.EntityResponse; import com.linkedin.entity.client.EntityClient; import com.linkedin.metadata.Constants; +import com.linkedin.metadata.query.AutoCompleteResult; +import com.linkedin.metadata.query.filter.Filter; import graphql.execution.DataFetcherResult; +import org.apache.commons.lang3.NotImplementedException; import javax.annotation.Nonnull; +import javax.annotation.Nullable; import java.util.ArrayList; import java.util.HashSet; import java.util.List; @@ -22,7 +31,10 @@ import java.util.function.Function; import java.util.stream.Collectors; -public class DataPlatformInstanceType implements com.linkedin.datahub.graphql.types.EntityType<DataPlatformInstance, String> { +import static com.linkedin.metadata.Constants.DATA_PLATFORM_INSTANCE_ENTITY_NAME; + +public class DataPlatformInstanceType implements SearchableEntityType<DataPlatformInstance, String>, +
com.linkedin.datahub.graphql.types.EntityType<DataPlatformInstance, String> { static final Set<String> ASPECTS_TO_FETCH = ImmutableSet.of( Constants.DATA_PLATFORM_INSTANCE_KEY_ASPECT_NAME, @@ -84,4 +96,24 @@ public List<DataFetcherResult<DataPlatformInstance>> batchLoad(@Nonnull List<St + + @Override + public SearchResults search(@Nonnull String query, + @Nullable List<FacetFilterInput> filters, + int start, + int count, + @Nonnull final QueryContext context) throws Exception { + throw new NotImplementedException("Searchable type (deprecated) not implemented on DataPlatformInstance entity type"); + } + + @Override + public AutoCompleteResults autoComplete(@Nonnull String query, + @Nullable String field, + @Nullable Filter filters, + int limit, + @Nonnull final QueryContext context) throws Exception { + final AutoCompleteResult result = _entityClient.autoComplete(DATA_PLATFORM_INSTANCE_ENTITY_NAME, query, + filters, limit, context.getAuthentication()); + return AutoCompleteResultsMapper.map(result); + } + } diff --git a/datahub-graphql-core/src/main/resources/entity.graphql b/datahub-graphql-core/src/main/resources/entity.graphql index 39f86948c77c40..0b15d7b875a9ca 100644 --- a/datahub-graphql-core/src/main/resources/entity.graphql +++ b/datahub-graphql-core/src/main/resources/entity.graphql @@ -226,6 +226,11 @@ type Query { listOwnershipTypes( "Input required for listing custom ownership types" input: ListOwnershipTypesInput!): ListOwnershipTypesResult! + + """ + Fetch a Data Platform Instance by primary key (urn) + """ + dataPlatformInstance(urn: String!): DataPlatformInstance } """ From dfcea2441e75e1eef517c0f9a4765e6e7990f297 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20G=C3=B3mez=20Villamor?= Date: Wed, 11 Oct 2023 03:04:44 +0200 Subject: [PATCH 3/7] feat(config): configurable bootstrap policies file (#8812) Co-authored-by: John Joyce --- .../configuration/src/main/resources/application.yml | 4 ++++ .../boot/factories/BootstrapManagerFactory.java | 7 ++++++- .../metadata/boot/steps/IngestPoliciesStep.java | 10 +++++++--- 3 files changed, 17 insertions(+), 4 deletions(-) diff --git a/metadata-service/configuration/src/main/resources/application.yml b/metadata-service/configuration/src/main/resources/application.yml index 4dfd96ac75c6ce..d22f92adca8f9c 100644 --- a/metadata-service/configuration/src/main/resources/application.yml +++ b/metadata-service/configuration/src/main/resources/application.yml @@ -276,6 +276,10 @@ bootstrap: enabled: ${UPGRADE_DEFAULT_BROWSE_PATHS_ENABLED:false} # enable to run the upgrade to migrate legacy default browse paths to new ones backfillBrowsePathsV2: enabled: ${BACKFILL_BROWSE_PATHS_V2:false} # Enables running the backfill of browsePathsV2 upgrade step. There are concerns about the load of this step so hiding it behind a flag.
Deprecating in favor of running through SystemUpdate + policies: + file: ${BOOTSTRAP_POLICIES_FILE:classpath:boot/policies.json} + # eg for local file + # file: "file:///datahub/datahub-gms/resources/custom-policies.json" servlets: waitTimeout: ${BOOTSTRAP_SERVLETS_WAITTIMEOUT:60} # Total waiting time in seconds for servlets to initialize diff --git a/metadata-service/factories/src/main/java/com/linkedin/metadata/boot/factories/BootstrapManagerFactory.java b/metadata-service/factories/src/main/java/com/linkedin/metadata/boot/factories/BootstrapManagerFactory.java index c490f000212010..3a761bd12647e6 100644 --- a/metadata-service/factories/src/main/java/com/linkedin/metadata/boot/factories/BootstrapManagerFactory.java +++ b/metadata-service/factories/src/main/java/com/linkedin/metadata/boot/factories/BootstrapManagerFactory.java @@ -31,6 +31,7 @@ import com.linkedin.metadata.search.EntitySearchService; import com.linkedin.metadata.search.SearchService; import com.linkedin.metadata.search.transformer.SearchDocumentTransformer; + import java.util.ArrayList; import java.util.List; import javax.annotation.Nonnull; @@ -41,6 +42,7 @@ import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Import; import org.springframework.context.annotation.Scope; +import org.springframework.core.io.Resource; @Configuration @@ -89,13 +91,16 @@ public class BootstrapManagerFactory { @Value("${bootstrap.backfillBrowsePathsV2.enabled}") private Boolean _backfillBrowsePathsV2Enabled; + @Value("${bootstrap.policies.file}") + private Resource _policiesResource; + @Bean(name = "bootstrapManager") @Scope("singleton") @Nonnull protected BootstrapManager createInstance() { final IngestRootUserStep ingestRootUserStep = new IngestRootUserStep(_entityService); final IngestPoliciesStep ingestPoliciesStep = - new IngestPoliciesStep(_entityRegistry, _entityService, _entitySearchService, _searchDocumentTransformer); + new IngestPoliciesStep(_entityRegistry, _entityService, _entitySearchService, _searchDocumentTransformer, _policiesResource); final IngestRolesStep ingestRolesStep = new IngestRolesStep(_entityService, _entityRegistry); final IngestDataPlatformsStep ingestDataPlatformsStep = new IngestDataPlatformsStep(_entityService); final IngestDataPlatformInstancesStep ingestDataPlatformInstancesStep = diff --git a/metadata-service/factories/src/main/java/com/linkedin/metadata/boot/steps/IngestPoliciesStep.java b/metadata-service/factories/src/main/java/com/linkedin/metadata/boot/steps/IngestPoliciesStep.java index 87dcfd736da401..cf296452144664 100644 --- a/metadata-service/factories/src/main/java/com/linkedin/metadata/boot/steps/IngestPoliciesStep.java +++ b/metadata-service/factories/src/main/java/com/linkedin/metadata/boot/steps/IngestPoliciesStep.java @@ -25,6 +25,7 @@ import com.linkedin.mxe.GenericAspect; import com.linkedin.mxe.MetadataChangeProposal; import com.linkedin.policy.DataHubPolicyInfo; + import java.io.IOException; import java.net.URISyntaxException; import java.util.Collections; @@ -35,7 +36,8 @@ import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; -import org.springframework.core.io.ClassPathResource; +import org.springframework.core.io.Resource; + import static com.linkedin.metadata.Constants.*; @@ -52,6 +54,8 @@ public class IngestPoliciesStep implements BootstrapStep { private final EntitySearchService _entitySearchService; private final SearchDocumentTransformer _searchDocumentTransformer; + private final Resource _policiesResource; 
+ @Override public String name() { return "IngestPoliciesStep"; @@ -66,10 +70,10 @@ public void execute() throws IOException, URISyntaxException { .maxStringLength(maxSize).build()); // 0. Execute preflight check to see whether we need to ingest policies - log.info("Ingesting default access policies..."); + log.info("Ingesting default access policies from: {}...", _policiesResource); // 1. Read from the file into JSON. - final JsonNode policiesObj = mapper.readTree(new ClassPathResource("./boot/policies.json").getFile()); + final JsonNode policiesObj = mapper.readTree(_policiesResource.getFile()); if (!policiesObj.isArray()) { throw new RuntimeException( From 10a190470e8c932b6d34cba49de7dbcba687a088 Mon Sep 17 00:00:00 2001 From: siddiquebagwan-gslab Date: Wed, 11 Oct 2023 08:54:08 +0530 Subject: [PATCH 4/7] feat(ingestion/redshift): CLL support in redshift (#8921) --- .../ingestion/source/redshift/config.py | 4 + .../ingestion/source/redshift/lineage.py | 215 +++++++++++++----- .../ingestion/source/redshift/redshift.py | 1 + .../tests/unit/test_redshift_lineage.py | 95 ++++++-- 4 files changed, 234 insertions(+), 81 deletions(-) diff --git a/metadata-ingestion/src/datahub/ingestion/source/redshift/config.py b/metadata-ingestion/src/datahub/ingestion/source/redshift/config.py index 804a14b0fe1cfb..2789b800940db2 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/redshift/config.py +++ b/metadata-ingestion/src/datahub/ingestion/source/redshift/config.py @@ -132,6 +132,10 @@ class RedshiftConfig( description="Whether `schema_pattern` is matched against fully qualified schema name `.`.", ) + extract_column_level_lineage: bool = Field( + default=True, description="Whether to extract column level lineage." + ) + @root_validator(pre=True) def check_email_is_set_on_usage(cls, values): if values.get("include_usage_statistics"): diff --git a/metadata-ingestion/src/datahub/ingestion/source/redshift/lineage.py b/metadata-ingestion/src/datahub/ingestion/source/redshift/lineage.py index bbe52b5d98ba36..c9ddfbe92ab2ab 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/redshift/lineage.py +++ b/metadata-ingestion/src/datahub/ingestion/source/redshift/lineage.py @@ -9,10 +9,12 @@ import humanfriendly import redshift_connector -from sqllineage.runner import LineageRunner +import datahub.emitter.mce_builder as builder +import datahub.utilities.sqlglot_lineage as sqlglot_l from datahub.emitter import mce_builder from datahub.emitter.mce_builder import make_dataset_urn_with_platform_instance +from datahub.ingestion.api.common import PipelineContext from datahub.ingestion.source.aws.s3_util import strip_s3_prefix from datahub.ingestion.source.redshift.common import get_db_name from datahub.ingestion.source.redshift.config import LineageMode, RedshiftConfig @@ -28,13 +30,19 @@ from datahub.ingestion.source.state.redundant_run_skip_handler import ( RedundantLineageRunSkipHandler, ) -from datahub.metadata.com.linkedin.pegasus2avro.dataset import UpstreamLineage +from datahub.metadata.com.linkedin.pegasus2avro.dataset import ( + FineGrainedLineage, + FineGrainedLineageDownstreamType, + FineGrainedLineageUpstreamType, + UpstreamLineage, +) from datahub.metadata.schema_classes import ( DatasetLineageTypeClass, UpstreamClass, UpstreamLineageClass, ) from datahub.utilities import memory_footprint +from datahub.utilities.urns import dataset_urn logger: logging.Logger = logging.getLogger(__name__) @@ -56,13 +64,14 @@ class LineageCollectorType(Enum): @dataclass(frozen=True, eq=True) class 
LineageDataset: platform: LineageDatasetPlatform - path: str + urn: str @dataclass() class LineageItem: dataset: LineageDataset upstreams: Set[LineageDataset] + cll: Optional[List[sqlglot_l.ColumnLineageInfo]] collector_type: LineageCollectorType dataset_lineage_type: str = field(init=False) @@ -83,10 +92,12 @@ def __init__( self, config: RedshiftConfig, report: RedshiftReport, + context: PipelineContext, redundant_run_skip_handler: Optional[RedundantLineageRunSkipHandler] = None, ): self.config = config self.report = report + self.context = context self._lineage_map: Dict[str, LineageItem] = defaultdict() self.redundant_run_skip_handler = redundant_run_skip_handler @@ -121,33 +132,37 @@ def _get_s3_path(self, path: str) -> str: return path - def _get_sources_from_query(self, db_name: str, query: str) -> List[LineageDataset]: + def _get_sources_from_query( + self, db_name: str, query: str + ) -> Tuple[List[LineageDataset], Optional[List[sqlglot_l.ColumnLineageInfo]]]: sources: List[LineageDataset] = list() - parser = LineageRunner(query) + parsed_result: Optional[ + sqlglot_l.SqlParsingResult + ] = sqlglot_l.create_lineage_sql_parsed_result( + query=query, + platform=LineageDatasetPlatform.REDSHIFT.value, + platform_instance=self.config.platform_instance, + database=db_name, + schema=str(self.config.default_schema), + graph=self.context.graph, + env=self.config.env, + ) - for table in parser.source_tables: - split = str(table).split(".") - if len(split) == 3: - db_name, source_schema, source_table = split - elif len(split) == 2: - source_schema, source_table = split - else: - raise ValueError( - f"Invalid table name {table} in query {query}. " - f"Expected format: [db_name].[schema].[table] or [schema].[table] or [table]." - ) + if parsed_result is None: + logger.debug(f"native query parsing failed for {query}") + return sources, None - if source_schema == "": - source_schema = str(self.config.default_schema) + logger.debug(f"parsed_result = {parsed_result}") + for table_urn in parsed_result.in_tables: source = LineageDataset( platform=LineageDatasetPlatform.REDSHIFT, - path=f"{db_name}.{source_schema}.{source_table}", + urn=table_urn, ) sources.append(source) - return sources + return sources, parsed_result.column_lineage def _build_s3_path_from_row(self, filename: str) -> str: path = filename.strip() @@ -165,9 +180,11 @@ def _get_sources( source_table: Optional[str], ddl: Optional[str], filename: Optional[str], - ) -> List[LineageDataset]: + ) -> Tuple[List[LineageDataset], Optional[List[sqlglot_l.ColumnLineageInfo]]]: sources: List[LineageDataset] = list() # Source + cll: Optional[List[sqlglot_l.ColumnLineageInfo]] = None + if ( lineage_type in { @@ -177,7 +194,7 @@ def _get_sources( and ddl is not None ): try: - sources = self._get_sources_from_query(db_name=db_name, query=ddl) + sources, cll = self._get_sources_from_query(db_name=db_name, query=ddl) except Exception as e: logger.warning( f"Error parsing query {ddl} for getting lineage. Error was {e}." @@ -192,22 +209,38 @@ def _get_sources( "Only s3 source supported with copy. The source was: {path}." 
) self.report.num_lineage_dropped_not_support_copy_path += 1 - return sources + return sources, cll path = strip_s3_prefix(self._get_s3_path(path)) + urn = make_dataset_urn_with_platform_instance( + platform=platform.value, + name=path, + env=self.config.env, + platform_instance=self.config.platform_instance_map.get( + platform.value + ) + if self.config.platform_instance_map is not None + else None, + ) elif source_schema is not None and source_table is not None: platform = LineageDatasetPlatform.REDSHIFT path = f"{db_name}.{source_schema}.{source_table}" + urn = make_dataset_urn_with_platform_instance( + platform=platform.value, + platform_instance=self.config.platform_instance, + name=path, + env=self.config.env, + ) else: - return [] + return [], cll sources = [ LineageDataset( platform=platform, - path=path, + urn=urn, ) ] - return sources + return sources, cll def _populate_lineage_map( self, @@ -231,6 +264,7 @@ def _populate_lineage_map( :rtype: None """ try: + cll: Optional[List[sqlglot_l.ColumnLineageInfo]] = None raw_db_name = database alias_db_name = get_db_name(self.config) @@ -243,7 +277,7 @@ def _populate_lineage_map( if not target: continue - sources = self._get_sources( + sources, cll = self._get_sources( lineage_type, alias_db_name, source_schema=lineage_row.source_schema, @@ -251,6 +285,7 @@ def _populate_lineage_map( ddl=lineage_row.ddl, filename=lineage_row.filename, ) + target.cll = cll target.upstreams.update( self._get_upstream_lineages( @@ -262,20 +297,16 @@ def _populate_lineage_map( ) # Merging downstreams if dataset already exists and has downstreams - if target.dataset.path in self._lineage_map: - self._lineage_map[ - target.dataset.path - ].upstreams = self._lineage_map[ - target.dataset.path - ].upstreams.union( - target.upstreams - ) + if target.dataset.urn in self._lineage_map: + self._lineage_map[target.dataset.urn].upstreams = self._lineage_map[ + target.dataset.urn + ].upstreams.union(target.upstreams) else: - self._lineage_map[target.dataset.path] = target + self._lineage_map[target.dataset.urn] = target logger.debug( - f"Lineage[{target}]:{self._lineage_map[target.dataset.path]}" + f"Lineage[{target}]:{self._lineage_map[target.dataset.urn]}" ) except Exception as e: self.warn( @@ -308,17 +339,34 @@ def _get_target_lineage( target_platform = LineageDatasetPlatform.S3 # Following call requires 'filename' key in lineage_row target_path = self._build_s3_path_from_row(lineage_row.filename) + urn = make_dataset_urn_with_platform_instance( + platform=target_platform.value, + name=target_path, + env=self.config.env, + platform_instance=self.config.platform_instance_map.get( + target_platform.value + ) + if self.config.platform_instance_map is not None + else None, + ) except ValueError as e: self.warn(logger, "non-s3-lineage", str(e)) return None else: target_platform = LineageDatasetPlatform.REDSHIFT target_path = f"{alias_db_name}.{lineage_row.target_schema}.{lineage_row.target_table}" + urn = make_dataset_urn_with_platform_instance( + platform=target_platform.value, + platform_instance=self.config.platform_instance, + name=target_path, + env=self.config.env, + ) return LineageItem( - dataset=LineageDataset(platform=target_platform, path=target_path), + dataset=LineageDataset(platform=target_platform, urn=urn), upstreams=set(), collector_type=lineage_type, + cll=None, ) def _get_upstream_lineages( @@ -331,11 +379,22 @@ def _get_upstream_lineages( targe_source = [] for source in sources: if source.platform == LineageDatasetPlatform.REDSHIFT: - db, schema, 
table = source.path.split(".") + qualified_table_name = dataset_urn.DatasetUrn.create_from_string( + source.urn + ).get_entity_id()[1] + db, schema, table = qualified_table_name.split(".") if db == raw_db_name: db = alias_db_name path = f"{db}.{schema}.{table}" - source = LineageDataset(platform=source.platform, path=path) + source = LineageDataset( + platform=source.platform, + urn=make_dataset_urn_with_platform_instance( + platform=LineageDatasetPlatform.REDSHIFT.value, + platform_instance=self.config.platform_instance, + name=path, + env=self.config.env, + ), + ) # Filtering out tables which does not exist in Redshift # It was deleted in the meantime or query parser did not capture well the table name @@ -345,7 +404,7 @@ def _get_upstream_lineages( or not any(table == t.name for t in all_tables[db][schema]) ): logger.debug( - f"{source.path} missing table, dropping from lineage.", + f"{source.urn} missing table, dropping from lineage.", ) self.report.num_lineage_tables_dropped += 1 continue @@ -433,36 +492,73 @@ def populate_lineage( memory_footprint.total_size(self._lineage_map) ) + def make_fine_grained_lineage_class( + self, lineage_item: LineageItem, dataset_urn: str + ) -> List[FineGrainedLineage]: + fine_grained_lineages: List[FineGrainedLineage] = [] + + if ( + self.config.extract_column_level_lineage is False + or lineage_item.cll is None + ): + logger.debug("CLL extraction is disabled") + return fine_grained_lineages + + logger.debug("Extracting column level lineage") + + cll: List[sqlglot_l.ColumnLineageInfo] = lineage_item.cll + + for cll_info in cll: + downstream = ( + [builder.make_schema_field_urn(dataset_urn, cll_info.downstream.column)] + if cll_info.downstream is not None + and cll_info.downstream.column is not None + else [] + ) + + upstreams = [ + builder.make_schema_field_urn(column_ref.table, column_ref.column) + for column_ref in cll_info.upstreams + ] + + fine_grained_lineages.append( + FineGrainedLineage( + downstreamType=FineGrainedLineageDownstreamType.FIELD, + downstreams=downstream, + upstreamType=FineGrainedLineageUpstreamType.FIELD_SET, + upstreams=upstreams, + ) + ) + + logger.debug(f"Created fine_grained_lineage for {dataset_urn}") + + return fine_grained_lineages + def get_lineage( self, table: Union[RedshiftTable, RedshiftView], dataset_urn: str, schema: RedshiftSchema, ) -> Optional[Tuple[UpstreamLineageClass, Dict[str, str]]]: - dataset_key = mce_builder.dataset_urn_to_key(dataset_urn) - if dataset_key is None: - return None upstream_lineage: List[UpstreamClass] = [] - if dataset_key.name in self._lineage_map: - item = self._lineage_map[dataset_key.name] + cll_lineage: List[FineGrainedLineage] = [] + + if dataset_urn in self._lineage_map: + item = self._lineage_map[dataset_urn] for upstream in item.upstreams: upstream_table = UpstreamClass( - dataset=make_dataset_urn_with_platform_instance( - upstream.platform.value, - upstream.path, - platform_instance=self.config.platform_instance_map.get( - upstream.platform.value - ) - if self.config.platform_instance_map - else None, - env=self.config.env, - ), + dataset=upstream.urn, type=item.dataset_lineage_type, ) upstream_lineage.append(upstream_table) + cll_lineage = self.make_fine_grained_lineage_class( + lineage_item=item, + dataset_urn=dataset_urn, + ) + tablename = table.name if table.type == "EXTERNAL_TABLE": # external_db_params = schema.option @@ -489,7 +585,12 @@ def get_lineage( else: return None - return UpstreamLineage(upstreams=upstream_lineage), {} + return ( + UpstreamLineage( + 
upstreams=upstream_lineage, fineGrainedLineages=cll_lineage or None + ), + {}, + ) def report_status(self, step: str, status: bool) -> None: if self.redundant_run_skip_handler: diff --git a/metadata-ingestion/src/datahub/ingestion/source/redshift/redshift.py b/metadata-ingestion/src/datahub/ingestion/source/redshift/redshift.py index e8a8ff976afa6c..a1b6333a3775d4 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/redshift/redshift.py +++ b/metadata-ingestion/src/datahub/ingestion/source/redshift/redshift.py @@ -881,6 +881,7 @@ def extract_lineage( self.lineage_extractor = RedshiftLineageExtractor( config=self.config, report=self.report, + context=self.ctx, redundant_run_skip_handler=self.redundant_lineage_run_skip_handler, ) diff --git a/metadata-ingestion/tests/unit/test_redshift_lineage.py b/metadata-ingestion/tests/unit/test_redshift_lineage.py index c7d6ac18e044cb..db5af3a71efb99 100644 --- a/metadata-ingestion/tests/unit/test_redshift_lineage.py +++ b/metadata-ingestion/tests/unit/test_redshift_lineage.py @@ -1,6 +1,8 @@ +from datahub.ingestion.api.common import PipelineContext from datahub.ingestion.source.redshift.config import RedshiftConfig from datahub.ingestion.source.redshift.lineage import RedshiftLineageExtractor from datahub.ingestion.source.redshift.report import RedshiftReport +from datahub.utilities.sqlglot_lineage import ColumnLineageInfo, DownstreamColumnRef def test_get_sources_from_query(): @@ -10,14 +12,20 @@ def test_get_sources_from_query(): test_query = """ select * from my_schema.my_table """ - lineage_extractor = RedshiftLineageExtractor(config, report) - lineage_datasets = lineage_extractor._get_sources_from_query( + lineage_extractor = RedshiftLineageExtractor( + config, report, PipelineContext(run_id="foo") + ) + lineage_datasets, _ = lineage_extractor._get_sources_from_query( db_name="test", query=test_query ) assert len(lineage_datasets) == 1 lineage = lineage_datasets[0] - assert lineage.path == "test.my_schema.my_table" + + assert ( + lineage.urn + == "urn:li:dataset:(urn:li:dataPlatform:redshift,test.my_schema.my_table,PROD)" + ) def test_get_sources_from_query_with_only_table_name(): @@ -27,14 +35,20 @@ def test_get_sources_from_query_with_only_table_name(): test_query = """ select * from my_table """ - lineage_extractor = RedshiftLineageExtractor(config, report) - lineage_datasets = lineage_extractor._get_sources_from_query( + lineage_extractor = RedshiftLineageExtractor( + config, report, PipelineContext(run_id="foo") + ) + lineage_datasets, _ = lineage_extractor._get_sources_from_query( db_name="test", query=test_query ) assert len(lineage_datasets) == 1 lineage = lineage_datasets[0] - assert lineage.path == "test.public.my_table" + + assert ( + lineage.urn + == "urn:li:dataset:(urn:li:dataPlatform:redshift,test.public.my_table,PROD)" + ) def test_get_sources_from_query_with_database(): @@ -44,14 +58,20 @@ def test_get_sources_from_query_with_database(): test_query = """ select * from test.my_schema.my_table """ - lineage_extractor = RedshiftLineageExtractor(config, report) - lineage_datasets = lineage_extractor._get_sources_from_query( + lineage_extractor = RedshiftLineageExtractor( + config, report, PipelineContext(run_id="foo") + ) + lineage_datasets, _ = lineage_extractor._get_sources_from_query( db_name="test", query=test_query ) assert len(lineage_datasets) == 1 lineage = lineage_datasets[0] - assert lineage.path == "test.my_schema.my_table" + + assert ( + lineage.urn + == 
"urn:li:dataset:(urn:li:dataPlatform:redshift,test.my_schema.my_table,PROD)" + ) def test_get_sources_from_query_with_non_default_database(): @@ -61,14 +81,20 @@ def test_get_sources_from_query_with_non_default_database(): test_query = """ select * from test2.my_schema.my_table """ - lineage_extractor = RedshiftLineageExtractor(config, report) - lineage_datasets = lineage_extractor._get_sources_from_query( + lineage_extractor = RedshiftLineageExtractor( + config, report, PipelineContext(run_id="foo") + ) + lineage_datasets, _ = lineage_extractor._get_sources_from_query( db_name="test", query=test_query ) assert len(lineage_datasets) == 1 lineage = lineage_datasets[0] - assert lineage.path == "test2.my_schema.my_table" + + assert ( + lineage.urn + == "urn:li:dataset:(urn:li:dataPlatform:redshift,test2.my_schema.my_table,PROD)" + ) def test_get_sources_from_query_with_only_table(): @@ -78,27 +104,48 @@ def test_get_sources_from_query_with_only_table(): test_query = """ select * from my_table """ - lineage_extractor = RedshiftLineageExtractor(config, report) - lineage_datasets = lineage_extractor._get_sources_from_query( + lineage_extractor = RedshiftLineageExtractor( + config, report, PipelineContext(run_id="foo") + ) + lineage_datasets, _ = lineage_extractor._get_sources_from_query( db_name="test", query=test_query ) assert len(lineage_datasets) == 1 lineage = lineage_datasets[0] - assert lineage.path == "test.public.my_table" + + assert ( + lineage.urn + == "urn:li:dataset:(urn:li:dataPlatform:redshift,test.public.my_table,PROD)" + ) -def test_get_sources_from_query_with_four_part_table_should_throw_exception(): +def test_cll(): config = RedshiftConfig(host_port="localhost:5439", database="test") report = RedshiftReport() test_query = """ - select * from database.schema.my_table.test + select a,b,c from db.public.customer inner join db.public.order on db.public.customer.id = db.public.order.customer_id """ - lineage_extractor = RedshiftLineageExtractor(config, report) - try: - lineage_extractor._get_sources_from_query(db_name="test", query=test_query) - except ValueError: - pass - - assert f"{test_query} should have thrown a ValueError exception but it didn't" + lineage_extractor = RedshiftLineageExtractor( + config, report, PipelineContext(run_id="foo") + ) + _, cll = lineage_extractor._get_sources_from_query(db_name="db", query=test_query) + + assert cll == [ + ColumnLineageInfo( + downstream=DownstreamColumnRef(table=None, column="a"), + upstreams=[], + logic=None, + ), + ColumnLineageInfo( + downstream=DownstreamColumnRef(table=None, column="b"), + upstreams=[], + logic=None, + ), + ColumnLineageInfo( + downstream=DownstreamColumnRef(table=None, column="c"), + upstreams=[], + logic=None, + ), + ] From 4b6b941a2abf13854511c9af0e88a17d5acfd5e6 Mon Sep 17 00:00:00 2001 From: Harsha Mandadi <115464537+harsha-mandadi-4026@users.noreply.github.com> Date: Wed, 11 Oct 2023 19:01:46 +0100 Subject: [PATCH 5/7] fix(ingest): Fix postgres lineage within views (#8906) Co-authored-by: Harshal Sheth Co-authored-by: Maggie Hays --- .../datahub/ingestion/source/sql/postgres.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/postgres.py b/metadata-ingestion/src/datahub/ingestion/source/sql/postgres.py index ba8655b83446d6..a6a9d8e2c8597c 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/postgres.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/postgres.py @@ -217,14 +217,15 @@ def 
_get_view_lineage_elements( key = (lineage.dependent_view, lineage.dependent_schema) # Append the source table to the list. lineage_elements[key].append( - mce_builder.make_dataset_urn( - self.platform, - self.get_identifier( + mce_builder.make_dataset_urn_with_platform_instance( + platform=self.platform, + name=self.get_identifier( schema=lineage.source_schema, entity=lineage.source_table, inspector=inspector, ), - self.config.env, + platform_instance=self.config.platform_instance, + env=self.config.env, ) ) @@ -244,12 +245,13 @@ def _get_view_lineage_workunits( dependent_view, dependent_schema = key # Construct a lineage object. - urn = mce_builder.make_dataset_urn( - self.platform, - self.get_identifier( + urn = mce_builder.make_dataset_urn_with_platform_instance( + platform=self.platform, + name=self.get_identifier( schema=dependent_schema, entity=dependent_view, inspector=inspector ), - self.config.env, + platform_instance=self.config.platform_instance, + env=self.config.env, ) # use the mce_builder to ensure that the change proposal inherits From 932fbcddbf7c3201898e0918218e80c9246b0cd2 Mon Sep 17 00:00:00 2001 From: Harshal Sheth Date: Wed, 11 Oct 2023 14:17:02 -0400 Subject: [PATCH 6/7] refactor(ingest/dbt): move dbt tests logic to dedicated file (#8984) --- .../src/datahub/ingestion/api/common.py | 9 + .../datahub/ingestion/source/csv_enricher.py | 8 +- .../datahub/ingestion/source/dbt/dbt_cloud.py | 3 +- .../ingestion/source/dbt/dbt_common.py | 278 +----------------- .../datahub/ingestion/source/dbt/dbt_core.py | 3 +- .../datahub/ingestion/source/dbt/dbt_tests.py | 261 ++++++++++++++++ 6 files changed, 288 insertions(+), 274 deletions(-) create mode 100644 metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_tests.py diff --git a/metadata-ingestion/src/datahub/ingestion/api/common.py b/metadata-ingestion/src/datahub/ingestion/api/common.py index 778bd119615e27..a6761a3c77d5e8 100644 --- a/metadata-ingestion/src/datahub/ingestion/api/common.py +++ b/metadata-ingestion/src/datahub/ingestion/api/common.py @@ -2,6 +2,7 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Dict, Generic, Iterable, Optional, Tuple, TypeVar +from datahub.configuration.common import ConfigurationError from datahub.emitter.mce_builder import set_dataset_urn_to_lower from datahub.ingestion.api.committable import Committable from datahub.ingestion.graph.client import DataHubGraph @@ -75,3 +76,11 @@ def register_checkpointer(self, committable: Committable) -> None: def get_committables(self) -> Iterable[Tuple[str, Committable]]: yield from self.checkpointers.items() + + def require_graph(self, operation: Optional[str] = None) -> DataHubGraph: + if not self.graph: + raise ConfigurationError( + f"{operation or 'This operation'} requires a graph, but none was provided. " + "To provide one, either use the datahub-rest sink or set the top-level datahub_api config in the recipe." + ) + return self.graph diff --git a/metadata-ingestion/src/datahub/ingestion/source/csv_enricher.py b/metadata-ingestion/src/datahub/ingestion/source/csv_enricher.py index 7cb487a86d9310..611f0c5c52cc65 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/csv_enricher.py +++ b/metadata-ingestion/src/datahub/ingestion/source/csv_enricher.py @@ -129,11 +129,9 @@ def __init__(self, config: CSVEnricherConfig, ctx: PipelineContext): # Map from entity urn to a list of SubResourceRow. 
self.editable_schema_metadata_map: Dict[str, List[SubResourceRow]] = {} self.should_overwrite: bool = self.config.write_semantics == "OVERRIDE" - if not self.should_overwrite and not self.ctx.graph: - raise ConfigurationError( - "With PATCH semantics, the csv-enricher source requires a datahub_api to connect to. " - "Consider using the datahub-rest sink or provide a datahub_api: configuration on your ingestion recipe." - ) + + if not self.should_overwrite: + self.ctx.require_graph(operation="The csv-enricher's PATCH semantics flag") def get_resource_glossary_terms_work_unit( self, diff --git a/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_cloud.py b/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_cloud.py index af9769bc9d94c9..da1ea8ecb4678a 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_cloud.py +++ b/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_cloud.py @@ -20,9 +20,8 @@ DBTCommonConfig, DBTNode, DBTSourceBase, - DBTTest, - DBTTestResult, ) +from datahub.ingestion.source.dbt.dbt_tests import DBTTest, DBTTestResult logger = logging.getLogger(__name__) diff --git a/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_common.py b/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_common.py index 0f5c08eb6ac549..48d2118a9b0917 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_common.py +++ b/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_common.py @@ -1,11 +1,10 @@ -import json import logging import re from abc import abstractmethod from dataclasses import dataclass, field from datetime import datetime from enum import auto -from typing import Any, Callable, ClassVar, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Tuple import pydantic from pydantic import root_validator, validator @@ -34,6 +33,12 @@ from datahub.ingestion.api.source import MetadataWorkUnitProcessor from datahub.ingestion.api.workunit import MetadataWorkUnit from datahub.ingestion.source.common.subtypes import DatasetSubTypes +from datahub.ingestion.source.dbt.dbt_tests import ( + DBTTest, + DBTTestResult, + make_assertion_from_test, + make_assertion_result_from_test, +) from datahub.ingestion.source.sql.sql_types import ( ATHENA_SQL_TYPES_MAP, BIGQUERY_TYPES_MAP, @@ -81,20 +86,7 @@ TimeTypeClass, ) from datahub.metadata.schema_classes import ( - AssertionInfoClass, - AssertionResultClass, - AssertionResultTypeClass, - AssertionRunEventClass, - AssertionRunStatusClass, - AssertionStdAggregationClass, - AssertionStdOperatorClass, - AssertionStdParameterClass, - AssertionStdParametersClass, - AssertionStdParameterTypeClass, - AssertionTypeClass, DataPlatformInstanceClass, - DatasetAssertionInfoClass, - DatasetAssertionScopeClass, DatasetPropertiesClass, GlobalTagsClass, GlossaryTermsClass, @@ -551,134 +543,6 @@ def get_column_type( return SchemaFieldDataType(type=TypeClass()) -@dataclass -class AssertionParams: - scope: Union[DatasetAssertionScopeClass, str] - operator: Union[AssertionStdOperatorClass, str] - aggregation: Union[AssertionStdAggregationClass, str] - parameters: Optional[Callable[[Dict[str, str]], AssertionStdParametersClass]] = None - logic_fn: Optional[Callable[[Dict[str, str]], Optional[str]]] = None - - -def _get_name_for_relationship_test(kw_args: Dict[str, str]) -> Optional[str]: - """ - Try to produce a useful string for the name of a relationship constraint. 
- Return None if we fail to - """ - destination_ref = kw_args.get("to") - source_ref = kw_args.get("model") - column_name = kw_args.get("column_name") - dest_field_name = kw_args.get("field") - if not destination_ref or not source_ref or not column_name or not dest_field_name: - # base assertions are violated, bail early - return None - m = re.match(r"^ref\(\'(.*)\'\)$", destination_ref) - if m: - destination_table = m.group(1) - else: - destination_table = destination_ref - m = re.search(r"ref\(\'(.*)\'\)", source_ref) - if m: - source_table = m.group(1) - else: - source_table = source_ref - return f"{source_table}.{column_name} referential integrity to {destination_table}.{dest_field_name}" - - -@dataclass -class DBTTest: - qualified_test_name: str - column_name: Optional[str] - kw_args: dict - - TEST_NAME_TO_ASSERTION_MAP: ClassVar[Dict[str, AssertionParams]] = { - "not_null": AssertionParams( - scope=DatasetAssertionScopeClass.DATASET_COLUMN, - operator=AssertionStdOperatorClass.NOT_NULL, - aggregation=AssertionStdAggregationClass.IDENTITY, - ), - "unique": AssertionParams( - scope=DatasetAssertionScopeClass.DATASET_COLUMN, - operator=AssertionStdOperatorClass.EQUAL_TO, - aggregation=AssertionStdAggregationClass.UNIQUE_PROPOTION, - parameters=lambda _: AssertionStdParametersClass( - value=AssertionStdParameterClass( - value="1.0", - type=AssertionStdParameterTypeClass.NUMBER, - ) - ), - ), - "accepted_values": AssertionParams( - scope=DatasetAssertionScopeClass.DATASET_COLUMN, - operator=AssertionStdOperatorClass.IN, - aggregation=AssertionStdAggregationClass.IDENTITY, - parameters=lambda kw_args: AssertionStdParametersClass( - value=AssertionStdParameterClass( - value=json.dumps(kw_args.get("values")), - type=AssertionStdParameterTypeClass.SET, - ), - ), - ), - "relationships": AssertionParams( - scope=DatasetAssertionScopeClass.DATASET_COLUMN, - operator=AssertionStdOperatorClass._NATIVE_, - aggregation=AssertionStdAggregationClass.IDENTITY, - parameters=lambda kw_args: AssertionStdParametersClass( - value=AssertionStdParameterClass( - value=json.dumps(kw_args.get("values")), - type=AssertionStdParameterTypeClass.SET, - ), - ), - logic_fn=_get_name_for_relationship_test, - ), - "dbt_expectations.expect_column_values_to_not_be_null": AssertionParams( - scope=DatasetAssertionScopeClass.DATASET_COLUMN, - operator=AssertionStdOperatorClass.NOT_NULL, - aggregation=AssertionStdAggregationClass.IDENTITY, - ), - "dbt_expectations.expect_column_values_to_be_between": AssertionParams( - scope=DatasetAssertionScopeClass.DATASET_COLUMN, - operator=AssertionStdOperatorClass.BETWEEN, - aggregation=AssertionStdAggregationClass.IDENTITY, - parameters=lambda x: AssertionStdParametersClass( - minValue=AssertionStdParameterClass( - value=str(x.get("min_value", "unknown")), - type=AssertionStdParameterTypeClass.NUMBER, - ), - maxValue=AssertionStdParameterClass( - value=str(x.get("max_value", "unknown")), - type=AssertionStdParameterTypeClass.NUMBER, - ), - ), - ), - "dbt_expectations.expect_column_values_to_be_in_set": AssertionParams( - scope=DatasetAssertionScopeClass.DATASET_COLUMN, - operator=AssertionStdOperatorClass.IN, - aggregation=AssertionStdAggregationClass.IDENTITY, - parameters=lambda kw_args: AssertionStdParametersClass( - value=AssertionStdParameterClass( - value=json.dumps(kw_args.get("value_set")), - type=AssertionStdParameterTypeClass.SET, - ), - ), - ), - } - - -@dataclass -class DBTTestResult: - invocation_id: str - - status: str - execution_time: datetime - - native_results: 
Dict[str, str] - - -def string_map(input_map: Dict[str, Any]) -> Dict[str, str]: - return {k: str(v) for k, v in input_map.items()} - - @platform_name("dbt") @config_class(DBTCommonConfig) @support_status(SupportStatus.CERTIFIED) @@ -750,7 +614,7 @@ def create_test_entity_mcps( for upstream_urn in sorted(upstream_urns): if self.config.entities_enabled.can_emit_node_type("test"): - yield self._make_assertion_from_test( + yield make_assertion_from_test( custom_props, node, assertion_urn, @@ -759,133 +623,17 @@ def create_test_entity_mcps( if node.test_result: if self.config.entities_enabled.can_emit_test_results: - yield self._make_assertion_result_from_test( - node, assertion_urn, upstream_urn + yield make_assertion_result_from_test( + node, + assertion_urn, + upstream_urn, + test_warnings_are_errors=self.config.test_warnings_are_errors, ) else: logger.debug( f"Skipping test result {node.name} emission since it is turned off." ) - def _make_assertion_from_test( - self, - extra_custom_props: Dict[str, str], - node: DBTNode, - assertion_urn: str, - upstream_urn: str, - ) -> MetadataWorkUnit: - assert node.test_info - qualified_test_name = node.test_info.qualified_test_name - column_name = node.test_info.column_name - kw_args = node.test_info.kw_args - - if qualified_test_name in DBTTest.TEST_NAME_TO_ASSERTION_MAP: - assertion_params = DBTTest.TEST_NAME_TO_ASSERTION_MAP[qualified_test_name] - assertion_info = AssertionInfoClass( - type=AssertionTypeClass.DATASET, - customProperties=extra_custom_props, - datasetAssertion=DatasetAssertionInfoClass( - dataset=upstream_urn, - scope=assertion_params.scope, - operator=assertion_params.operator, - fields=[ - mce_builder.make_schema_field_urn(upstream_urn, column_name) - ] - if ( - assertion_params.scope - == DatasetAssertionScopeClass.DATASET_COLUMN - and column_name - ) - else [], - nativeType=node.name, - aggregation=assertion_params.aggregation, - parameters=assertion_params.parameters(kw_args) - if assertion_params.parameters - else None, - logic=assertion_params.logic_fn(kw_args) - if assertion_params.logic_fn - else None, - nativeParameters=string_map(kw_args), - ), - ) - elif column_name: - # no match with known test types, column-level test - assertion_info = AssertionInfoClass( - type=AssertionTypeClass.DATASET, - customProperties=extra_custom_props, - datasetAssertion=DatasetAssertionInfoClass( - dataset=upstream_urn, - scope=DatasetAssertionScopeClass.DATASET_COLUMN, - operator=AssertionStdOperatorClass._NATIVE_, - fields=[ - mce_builder.make_schema_field_urn(upstream_urn, column_name) - ], - nativeType=node.name, - logic=node.compiled_code or node.raw_code, - aggregation=AssertionStdAggregationClass._NATIVE_, - nativeParameters=string_map(kw_args), - ), - ) - else: - # no match with known test types, default to row-level test - assertion_info = AssertionInfoClass( - type=AssertionTypeClass.DATASET, - customProperties=extra_custom_props, - datasetAssertion=DatasetAssertionInfoClass( - dataset=upstream_urn, - scope=DatasetAssertionScopeClass.DATASET_ROWS, - operator=AssertionStdOperatorClass._NATIVE_, - logic=node.compiled_code or node.raw_code, - nativeType=node.name, - aggregation=AssertionStdAggregationClass._NATIVE_, - nativeParameters=string_map(kw_args), - ), - ) - - wu = MetadataChangeProposalWrapper( - entityUrn=assertion_urn, - aspect=assertion_info, - ).as_workunit() - - return wu - - def _make_assertion_result_from_test( - self, - node: DBTNode, - assertion_urn: str, - upstream_urn: str, - ) -> MetadataWorkUnit: - assert 
node.test_result - test_result = node.test_result - - assertionResult = AssertionRunEventClass( - timestampMillis=int(test_result.execution_time.timestamp() * 1000.0), - assertionUrn=assertion_urn, - asserteeUrn=upstream_urn, - runId=test_result.invocation_id, - result=AssertionResultClass( - type=AssertionResultTypeClass.SUCCESS - if test_result.status == "pass" - or ( - not self.config.test_warnings_are_errors - and test_result.status == "warn" - ) - else AssertionResultTypeClass.FAILURE, - nativeResults=test_result.native_results, - ), - status=AssertionRunStatusClass.COMPLETE, - ) - - event = MetadataChangeProposalWrapper( - entityUrn=assertion_urn, - aspect=assertionResult, - ) - wu = MetadataWorkUnit( - id=f"{assertion_urn}-assertionRunEvent-{upstream_urn}", - mcp=event, - ) - return wu - @abstractmethod def load_nodes(self) -> Tuple[List[DBTNode], Dict[str, Optional[str]]]: # return dbt nodes + global custom properties diff --git a/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_core.py b/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_core.py index c08295ed1dc593..dc3a84847beb24 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_core.py +++ b/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_core.py @@ -26,9 +26,8 @@ DBTNode, DBTSourceBase, DBTSourceReport, - DBTTest, - DBTTestResult, ) +from datahub.ingestion.source.dbt.dbt_tests import DBTTest, DBTTestResult logger = logging.getLogger(__name__) diff --git a/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_tests.py b/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_tests.py new file mode 100644 index 00000000000000..721769d214d9e5 --- /dev/null +++ b/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_tests.py @@ -0,0 +1,261 @@ +import json +import re +from dataclasses import dataclass +from datetime import datetime +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union + +from datahub.emitter import mce_builder +from datahub.emitter.mcp import MetadataChangeProposalWrapper +from datahub.ingestion.api.workunit import MetadataWorkUnit +from datahub.metadata.schema_classes import ( + AssertionInfoClass, + AssertionResultClass, + AssertionResultTypeClass, + AssertionRunEventClass, + AssertionRunStatusClass, + AssertionStdAggregationClass, + AssertionStdOperatorClass, + AssertionStdParameterClass, + AssertionStdParametersClass, + AssertionStdParameterTypeClass, + AssertionTypeClass, + DatasetAssertionInfoClass, + DatasetAssertionScopeClass, +) + +if TYPE_CHECKING: + from datahub.ingestion.source.dbt.dbt_common import DBTNode + + +@dataclass +class DBTTest: + qualified_test_name: str + column_name: Optional[str] + kw_args: dict + + +@dataclass +class DBTTestResult: + invocation_id: str + + status: str + execution_time: datetime + + native_results: Dict[str, str] + + +def _get_name_for_relationship_test(kw_args: Dict[str, str]) -> Optional[str]: + """ + Try to produce a useful string for the name of a relationship constraint. 
+ Return None if we fail to + """ + destination_ref = kw_args.get("to") + source_ref = kw_args.get("model") + column_name = kw_args.get("column_name") + dest_field_name = kw_args.get("field") + if not destination_ref or not source_ref or not column_name or not dest_field_name: + # base assertions are violated, bail early + return None + m = re.match(r"^ref\(\'(.*)\'\)$", destination_ref) + if m: + destination_table = m.group(1) + else: + destination_table = destination_ref + m = re.search(r"ref\(\'(.*)\'\)", source_ref) + if m: + source_table = m.group(1) + else: + source_table = source_ref + return f"{source_table}.{column_name} referential integrity to {destination_table}.{dest_field_name}" + + +@dataclass +class AssertionParams: + scope: Union[DatasetAssertionScopeClass, str] + operator: Union[AssertionStdOperatorClass, str] + aggregation: Union[AssertionStdAggregationClass, str] + parameters: Optional[Callable[[Dict[str, str]], AssertionStdParametersClass]] = None + logic_fn: Optional[Callable[[Dict[str, str]], Optional[str]]] = None + + +_DBT_TEST_NAME_TO_ASSERTION_MAP: Dict[str, AssertionParams] = { + "not_null": AssertionParams( + scope=DatasetAssertionScopeClass.DATASET_COLUMN, + operator=AssertionStdOperatorClass.NOT_NULL, + aggregation=AssertionStdAggregationClass.IDENTITY, + ), + "unique": AssertionParams( + scope=DatasetAssertionScopeClass.DATASET_COLUMN, + operator=AssertionStdOperatorClass.EQUAL_TO, + aggregation=AssertionStdAggregationClass.UNIQUE_PROPOTION, + parameters=lambda _: AssertionStdParametersClass( + value=AssertionStdParameterClass( + value="1.0", + type=AssertionStdParameterTypeClass.NUMBER, + ) + ), + ), + "accepted_values": AssertionParams( + scope=DatasetAssertionScopeClass.DATASET_COLUMN, + operator=AssertionStdOperatorClass.IN, + aggregation=AssertionStdAggregationClass.IDENTITY, + parameters=lambda kw_args: AssertionStdParametersClass( + value=AssertionStdParameterClass( + value=json.dumps(kw_args.get("values")), + type=AssertionStdParameterTypeClass.SET, + ), + ), + ), + "relationships": AssertionParams( + scope=DatasetAssertionScopeClass.DATASET_COLUMN, + operator=AssertionStdOperatorClass._NATIVE_, + aggregation=AssertionStdAggregationClass.IDENTITY, + parameters=lambda kw_args: AssertionStdParametersClass( + value=AssertionStdParameterClass( + value=json.dumps(kw_args.get("values")), + type=AssertionStdParameterTypeClass.SET, + ), + ), + logic_fn=_get_name_for_relationship_test, + ), + "dbt_expectations.expect_column_values_to_not_be_null": AssertionParams( + scope=DatasetAssertionScopeClass.DATASET_COLUMN, + operator=AssertionStdOperatorClass.NOT_NULL, + aggregation=AssertionStdAggregationClass.IDENTITY, + ), + "dbt_expectations.expect_column_values_to_be_between": AssertionParams( + scope=DatasetAssertionScopeClass.DATASET_COLUMN, + operator=AssertionStdOperatorClass.BETWEEN, + aggregation=AssertionStdAggregationClass.IDENTITY, + parameters=lambda x: AssertionStdParametersClass( + minValue=AssertionStdParameterClass( + value=str(x.get("min_value", "unknown")), + type=AssertionStdParameterTypeClass.NUMBER, + ), + maxValue=AssertionStdParameterClass( + value=str(x.get("max_value", "unknown")), + type=AssertionStdParameterTypeClass.NUMBER, + ), + ), + ), + "dbt_expectations.expect_column_values_to_be_in_set": AssertionParams( + scope=DatasetAssertionScopeClass.DATASET_COLUMN, + operator=AssertionStdOperatorClass.IN, + aggregation=AssertionStdAggregationClass.IDENTITY, + parameters=lambda kw_args: AssertionStdParametersClass( + 
value=AssertionStdParameterClass( + value=json.dumps(kw_args.get("value_set")), + type=AssertionStdParameterTypeClass.SET, + ), + ), + ), +} + + +def _string_map(input_map: Dict[str, Any]) -> Dict[str, str]: + return {k: str(v) for k, v in input_map.items()} + + +def make_assertion_from_test( + extra_custom_props: Dict[str, str], + node: "DBTNode", + assertion_urn: str, + upstream_urn: str, +) -> MetadataWorkUnit: + assert node.test_info + qualified_test_name = node.test_info.qualified_test_name + column_name = node.test_info.column_name + kw_args = node.test_info.kw_args + + if qualified_test_name in _DBT_TEST_NAME_TO_ASSERTION_MAP: + assertion_params = _DBT_TEST_NAME_TO_ASSERTION_MAP[qualified_test_name] + assertion_info = AssertionInfoClass( + type=AssertionTypeClass.DATASET, + customProperties=extra_custom_props, + datasetAssertion=DatasetAssertionInfoClass( + dataset=upstream_urn, + scope=assertion_params.scope, + operator=assertion_params.operator, + fields=[mce_builder.make_schema_field_urn(upstream_urn, column_name)] + if ( + assertion_params.scope == DatasetAssertionScopeClass.DATASET_COLUMN + and column_name + ) + else [], + nativeType=node.name, + aggregation=assertion_params.aggregation, + parameters=assertion_params.parameters(kw_args) + if assertion_params.parameters + else None, + logic=assertion_params.logic_fn(kw_args) + if assertion_params.logic_fn + else None, + nativeParameters=_string_map(kw_args), + ), + ) + elif column_name: + # no match with known test types, column-level test + assertion_info = AssertionInfoClass( + type=AssertionTypeClass.DATASET, + customProperties=extra_custom_props, + datasetAssertion=DatasetAssertionInfoClass( + dataset=upstream_urn, + scope=DatasetAssertionScopeClass.DATASET_COLUMN, + operator=AssertionStdOperatorClass._NATIVE_, + fields=[mce_builder.make_schema_field_urn(upstream_urn, column_name)], + nativeType=node.name, + logic=node.compiled_code or node.raw_code, + aggregation=AssertionStdAggregationClass._NATIVE_, + nativeParameters=_string_map(kw_args), + ), + ) + else: + # no match with known test types, default to row-level test + assertion_info = AssertionInfoClass( + type=AssertionTypeClass.DATASET, + customProperties=extra_custom_props, + datasetAssertion=DatasetAssertionInfoClass( + dataset=upstream_urn, + scope=DatasetAssertionScopeClass.DATASET_ROWS, + operator=AssertionStdOperatorClass._NATIVE_, + logic=node.compiled_code or node.raw_code, + nativeType=node.name, + aggregation=AssertionStdAggregationClass._NATIVE_, + nativeParameters=_string_map(kw_args), + ), + ) + + return MetadataChangeProposalWrapper( + entityUrn=assertion_urn, + aspect=assertion_info, + ).as_workunit() + + +def make_assertion_result_from_test( + node: "DBTNode", + assertion_urn: str, + upstream_urn: str, + test_warnings_are_errors: bool, +) -> MetadataWorkUnit: + assert node.test_result + test_result = node.test_result + + assertionResult = AssertionRunEventClass( + timestampMillis=int(test_result.execution_time.timestamp() * 1000.0), + assertionUrn=assertion_urn, + asserteeUrn=upstream_urn, + runId=test_result.invocation_id, + result=AssertionResultClass( + type=AssertionResultTypeClass.SUCCESS + if test_result.status == "pass" + or (not test_warnings_are_errors and test_result.status == "warn") + else AssertionResultTypeClass.FAILURE, + nativeResults=test_result.native_results, + ), + status=AssertionRunStatusClass.COMPLETE, + ) + + return MetadataChangeProposalWrapper( + entityUrn=assertion_urn, + aspect=assertionResult, + ).as_workunit() From 
1b06c6a30c8d6c0ee57f75f75ee6a436aa6c13a7 Mon Sep 17 00:00:00 2001 From: Mayuri Nehate <33225191+mayurinehate@users.noreply.github.com> Date: Thu, 12 Oct 2023 00:31:42 +0530 Subject: [PATCH 7/7] fix(ingest/snowflake): fix sample fraction for very large tables (#8988) --- .../datahub/ingestion/source/snowflake/snowflake_profiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_profiler.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_profiler.py index 24275dcdff34dd..8e18d85d6f3ca3 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_profiler.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_profiler.py @@ -86,7 +86,7 @@ def get_batch_kwargs( # Fixed-size sampling can be slower than equivalent fraction-based sampling # as per https://docs.snowflake.com/en/sql-reference/constructs/sample#performance-considerations sample_pc = 100 * self.config.profiling.sample_size / table.rows_count - custom_sql = f'select * from "{db_name}"."{schema_name}"."{table.name}" TABLESAMPLE ({sample_pc:.3f})' + custom_sql = f'select * from "{db_name}"."{schema_name}"."{table.name}" TABLESAMPLE ({sample_pc:.8f})' return { **super().get_batch_kwargs(table, schema_name, db_name), # Lowercase/Mixedcase table names in Snowflake do not work by default.
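# A minimal sketch (not part of the patch) of the guard the csv-enricher hunk now
# delegates to via ctx.require_graph(operation=...). The real PipelineContext API is
# assumed; the error message is reconstructed from the check the patch removes.
from typing import Optional


class ConfigurationError(Exception):
    pass


def require_graph(graph: Optional[object], operation: str) -> None:
    # Fail fast when an operation needs a DataHub graph client but the recipe
    # did not provide one.
    if graph is None:
        raise ConfigurationError(
            f"{operation} requires a datahub_api to connect to. "
            "Consider using the datahub-rest sink or providing a datahub_api: "
            "configuration on your ingestion recipe."
        )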
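# A minimal sketch of the result-type mapping that make_assertion_result_from_test
# now receives as an explicit test_warnings_are_errors parameter: a "warn" status is
# reported as a failure only when that flag is set. Plain strings stand in for
# AssertionResultTypeClass so the snippet runs without DataHub imports.
def assertion_result_type(status: str, test_warnings_are_errors: bool) -> str:
    if status == "pass" or (not test_warnings_are_errors and status == "warn"):
        return "SUCCESS"
    return "FAILURE"


assert assertion_result_type("warn", test_warnings_are_errors=False) == "SUCCESS"
assert assertion_result_type("warn", test_warnings_are_errors=True) == "FAILURE"
assert assertion_result_type("error", test_warnings_are_errors=True) == "FAILURE"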
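# A small illustration (with made-up numbers) of why the snowflake_profiler hunk
# widens the TABLESAMPLE percentage from .3f to .8f: for very large tables the
# computed fraction rounds to 0.000, so the sample query returns no rows.
sample_size = 10_000           # hypothetical profiling.sample_size
rows_count = 50_000_000_000    # hypothetical row count of a very large table

sample_pc = 100 * sample_size / rows_count  # 2e-05 percent of the table

print(f"TABLESAMPLE ({sample_pc:.3f})")  # TABLESAMPLE (0.000) -> effectively empty sample
print(f"TABLESAMPLE ({sample_pc:.8f})")  # TABLESAMPLE (0.00002000) -> intended sample size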