Skip to content

Commit

Permalink
Merge pull request #6 from trocco-io/feature/oauth-m2m
Browse files Browse the repository at this point in the history
Implementation of oauth-m2m authentication
  • Loading branch information
yu-kioo authored Dec 26, 2024
2 parents 91d572e + 9356781 commit 272bd51
Show file tree
Hide file tree
Showing 12 changed files with 226 additions and 15 deletions.
7 changes: 6 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@ Databricks input plugin for Embulk loads records from Databricks.
- **product_version**: product version of user agent (string, default: "0.0.0")
- **server_hostname**: The Databricks compute resource’s Server Hostname value, see [Compute settings for the Databricks JDBC Driver](https://docs.databricks.com/en/integrations/jdbc/compute.html). (string, required)
- **http_path**: The Databricks compute resource’s HTTP Path value, see [Compute settings for the Databricks JDBC Driver](https://docs.databricks.com/en/integrations/jdbc/compute.html). (string, required)
- **personal_access_token**: The Databricks personal_access_token, see [Authentication settings for the Databricks JDBC Driver](https://docs.databricks.com/en/integrations/jdbc/authentication.html#authentication-pat). (string, required)
- **auth_type**: The Databricks authentication type, personal access token (PAT)-based or machine-to-machine (M2M) authentication. (`pat`, `oauth-m2m`, default: `pat`)
- If **auth_type** is `pat`,
- **personal_access_token**: The Databricks personal_access_token, see [Authentication settings for the Databricks JDBC Driver](https://docs.databricks.com/en/integrations/jdbc/authentication.html#authentication-pat). (string, required)
- If **auth_type** is `oauth-m2m`,
- **oauth2_client_id**: The Databricks oauth2_client_id, see [Use a service principal to authenticate with Databricks](https://docs.databricks.com/en/dev-tools/auth/oauth-m2m.html). (string, required)
- **oauth2_client_secret**: The Databricks oauth2_client_secret, see [Use a service principal to authenticate with Databricks](https://docs.databricks.com/en/dev-tools/auth/oauth-m2m.html). (string, required)
- **catalog_name**: destination catalog name (string, optional)
- **schema_name**: destination schema name (string, optional)
- **where**: WHERE condition to filter the rows (string, default: no-condition)
Expand Down
10 changes: 6 additions & 4 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ plugins {
id "com.palantir.git-version" version "0.13.0"
id "com.diffplug.spotless" version "5.15.0"
id "com.adarshr.test-logger" version "3.0.0"
id "com.github.johnrengelman.shadow" version "6.0.0" apply false
}

repositories {
Expand Down Expand Up @@ -32,7 +33,7 @@ dependencies {
compileOnly("org.embulk:embulk-api:${embulkVersion}")
compileOnly("org.embulk:embulk-spi:${embulkVersion}")
compile("org.embulk:embulk-input-jdbc:0.13.2")
compile('com.databricks:databricks-jdbc:2.6.34')
compile project(path: ":shadow-databricks-jdbc", configuration: "shadow")

testImplementation "junit:junit:4.+"
testImplementation "org.embulk:embulk-junit4:${embulkVersion}"
Expand All @@ -41,9 +42,10 @@ dependencies {
testImplementation "org.embulk:embulk-formatter-csv:${embulkVersion}"
testImplementation "org.embulk:embulk-output-file:${embulkVersion}"

//SLF4J: Failed to load class "org.slf4j.impl.StaticLoggerBinder".
//SLF4J: Defaulting to no-operation (NOP) logger implementation
//SLF4J: See http://www.slf4j.org/codes.html#StaticLoggerBinder for further details.
// Suppress the following logs in gradlew test.
// SLF4J: Failed to load class "org.slf4j.impl.StaticLoggerBinder".
// SLF4J: Defaulting to no-operation (NOP) logger implementation
// SLF4J: See http://www.slf4j.org/codes.html#StaticLoggerBinder for further details.
testImplementation("org.slf4j:slf4j-simple:1.7.30")
}

Expand Down
4 changes: 3 additions & 1 deletion example/test.yml.example
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
server_hostname:
http_path:
personal_access_token:
oauth2_client_id:
oauth2_client_secret:
catalog_name:
schema_name:
table_prefix:
table_prefix:
another_catalog_name:
1 change: 0 additions & 1 deletion gradle/dependency-locks/compileClasspath.lockfile
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# This is a Gradle generated file for dependency locking.
# Manual edits can break the build and are not advised.
# This file is expected to be part of source control.
com.databricks:databricks-jdbc:2.6.34
com.fasterxml.jackson.core:jackson-annotations:2.6.7
com.fasterxml.jackson.core:jackson-core:2.6.7
com.fasterxml.jackson.core:jackson-databind:2.6.7
Expand Down
1 change: 0 additions & 1 deletion gradle/dependency-locks/runtimeClasspath.lockfile
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# This is a Gradle generated file for dependency locking.
# Manual edits can break the build and are not advised.
# This file is expected to be part of source control.
com.databricks:databricks-jdbc:2.6.34
com.fasterxml.jackson.core:jackson-annotations:2.6.7
com.fasterxml.jackson.core:jackson-core:2.6.7
com.fasterxml.jackson.core:jackson-databind:2.6.7
Expand Down
2 changes: 2 additions & 0 deletions settings.gradle
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
rootProject.name = "embulk-input-databricks"
include "shadow-databricks-jdbc"
36 changes: 36 additions & 0 deletions shadow-databricks-jdbc/build.gradle
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Helper subproject that repackages the Databricks JDBC driver with the shadow plugin.
// The root project depends on this project's "shadow" configuration instead of
// depending on the driver artifact directly.
apply plugin: "java"
apply plugin: "com.github.johnrengelman.shadow"

repositories {
mavenCentral()
}

group = "io.trocco"
// Keep this subproject's version in step with the root project.
version = "${rootProject.version}"
description = "A helper library for embulk-input-databricks"

sourceCompatibility = 1.8
targetCompatibility = 1.8

configurations {
// Dependency locking is enabled for both configurations below.
runtimeClasspath {
resolutionStrategy.activateDependencyLocking()
}
shadow {
resolutionStrategy.activateDependencyLocking()
// NOTE(review): presumably disabled so consumers of the "shadow" configuration
// receive only the shaded jar and not the driver's own dependency graph - confirm.
transitive = false
}
}

dependencies {
// The JDBC driver that gets bundled into the shadow jar.
compile('com.databricks:databricks-jdbc:2.6.38')
}

shadowJar {
// suppress the following undesirable log (https://stackoverflow.com/a/61475766/24393181)
//
// ERROR StatusLogger Unrecognized format specifier [d]
// ERROR StatusLogger Unrecognized conversion specifier [d] starting at position 16 in conversion pattern.
// ...
exclude "**/Log4j2Plugins.dat"
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# This is a Gradle generated file for dependency locking.
# Manual edits can break the build and are not advised.
# This file is expected to be part of source control.
com.databricks:databricks-jdbc:2.6.38
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# This is a Gradle generated file for dependency locking.
# Manual edits can break the build and are not advised.
# This file is expected to be part of source control.
58 changes: 53 additions & 5 deletions src/main/java/org/embulk/input/DatabricksInputPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.SQLException;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.Properties;
import org.embulk.config.ConfigException;
Expand Down Expand Up @@ -30,8 +32,21 @@ public interface DatabricksPluginTask extends PluginTask {
/**
 * HTTP path of the Databricks compute resource, read from the {@code http_path} config key.
 *
 * <p>The connection setup passes this value to the JDBC driver as the {@code httpPath}
 * property.
 */
@Config("http_path")
public String getHTTPPath();

/**
 * Authentication method for the JDBC connection, read from the {@code auth_type} config key.
 *
 * <p>Defaults to {@code "pat"}. The connection setup in this plugin recognizes {@code "pat"}
 * and {@code "oauth-m2m"}; any other value is rejected with a {@code ConfigException}.
 */
@Config("auth_type")
@ConfigDefault("\"pat\"") // oauth-m2m or pat
public String getAuthType();

@Config("personal_access_token")
public String getPersonalAccessToken();
@ConfigDefault("null")
public Optional<String> getPersonalAccessToken();

/**
 * OAuth2 client id, read from the optional {@code oauth2_client_id} config key.
 *
 * <p>Only consulted when {@code auth_type} is {@code "oauth-m2m"}, where it is sent to the
 * driver as the {@code OAuth2ClientId} property. Its presence is enforced by
 * {@code fetchOauth2ClientId} rather than by the config mapping itself.
 */
@Config("oauth2_client_id")
@ConfigDefault("null")
public Optional<String> getOauth2ClientId();

/**
 * OAuth2 client secret, read from the optional {@code oauth2_client_secret} config key.
 *
 * <p>Only consulted when {@code auth_type} is {@code "oauth-m2m"}, where it is sent to the
 * driver as the {@code OAuth2Secret} property. Its presence is enforced by
 * {@code fetchOauth2ClientSecret} rather than by the config mapping itself.
 */
@Config("oauth2_client_secret")
@ConfigDefault("null")
public Optional<String> getOauth2ClientSecret();

@Config("catalog_name")
@ConfigDefault("null")
Expand All @@ -54,6 +69,25 @@ public interface UserAgentEntry extends Task {
@ConfigDefault("\"0.0.0\"")
public String getProductVersion();
}

/**
 * Returns the configured personal access token.
 *
 * @param t the task whose {@code personal_access_token} option is read
 * @return the token value
 * @throws ConfigException if {@code personal_access_token} is not set
 */
static String fetchPersonalAccessToken(DatabricksPluginTask t) {
  final Optional<String> token = t.getPersonalAccessToken();
  return validatePresence(token, "personal_access_token");
}

/**
 * Returns the configured OAuth2 client id.
 *
 * @param t the task whose {@code oauth2_client_id} option is read
 * @return the client id value
 * @throws ConfigException if {@code oauth2_client_id} is not set
 */
static String fetchOauth2ClientId(DatabricksPluginTask t) {
  final Optional<String> clientId = t.getOauth2ClientId();
  return validatePresence(clientId, "oauth2_client_id");
}

/**
 * Returns the configured OAuth2 client secret.
 *
 * @param t the task whose {@code oauth2_client_secret} option is read
 * @return the client secret value
 * @throws ConfigException if {@code oauth2_client_secret} is not set
 */
static String fetchOauth2ClientSecret(DatabricksPluginTask t) {
  final Optional<String> clientSecret = t.getOauth2ClientSecret();
  return validatePresence(clientSecret, "oauth2_client_secret");
}
}

/**
 * Unwraps an optional config value, failing when it is absent.
 *
 * @param val the optional value read from the task
 * @param varName the config key name used in the error message
 * @return the contained value
 * @throws ConfigException with the message {@code "<varName> must not be null."} when
 *     {@code val} is empty
 */
private static <T> T validatePresence(Optional<T> val, String varName) {
  return val.orElseThrow(
      () -> new ConfigException(String.format("%s must not be null.", varName)));
}

@Override
Expand All @@ -79,9 +113,22 @@ protected JdbcInputConnection newConnection(PluginTask task) throws SQLException
String url = String.format("jdbc:databricks://%s:443", t.getServerHostname());
Properties props = new java.util.Properties();
props.put("httpPath", t.getHTTPPath());
props.put("AuthMech", "3");
props.put("UID", "token");
props.put("PWD", t.getPersonalAccessToken());
String authType = t.getAuthType();
switch (authType) {
case "pat":
props.put("AuthMech", "3");
props.put("UID", "token");
props.put("PWD", DatabricksPluginTask.fetchPersonalAccessToken(t));
break;
case "oauth-m2m":
props.put("AuthMech", "11");
props.put("Auth_Flow", "1");
props.put("OAuth2ClientId", DatabricksPluginTask.fetchOauth2ClientId(t));
props.put("OAuth2Secret", DatabricksPluginTask.fetchOauth2ClientSecret(t));
break;
default:
throw new ConfigException(String.format("unknown auth_type '%s'", authType));
}
props.put("SSL", "1");
if (t.getCatalogName().isPresent()) {
props.put("ConnCatalog", t.getCatalogName().get());
Expand All @@ -104,10 +151,11 @@ protected JdbcInputConnection newConnection(PluginTask task) throws SQLException

@Override
protected void logConnectionProperties(String url, Properties props) {
List<String> maskedKeys = Arrays.asList("PWD", "OAuth2Secret");
Properties maskedProps = new Properties();
for (Object keyObj : props.keySet()) {
String key = (String) keyObj;
String maskedVal = key.equals("PWD") ? "***" : props.getProperty(key);
String maskedVal = maskedKeys.contains(key) ? "***" : props.getProperty(key);
maskedProps.setProperty(key, maskedVal);
}
super.logConnectionProperties(url, maskedProps);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package org.embulk.input.databricks;

import static org.embulk.test.EmbulkTests.readFile;

import java.io.IOException;
import java.nio.file.Path;
import java.util.function.Function;
import org.embulk.config.ConfigException;
import org.embulk.config.ConfigSource;
import org.embulk.exec.PartialExecutionException;
import org.embulk.input.databricks.util.ConfigUtil;
import org.embulk.input.databricks.util.ConnectionUtil;
import org.junit.Assert;
import org.junit.Test;

public class TestDatabricksInputPluginAuthType extends AbstractTestDatabricksInputPlugin {
@Test
public void testAuthTypeDefault() throws IOException {
final ConfigUtil.TestTask t = ConfigUtil.createTestTask();
testRun((x) -> x.set("personal_access_token", t.getPersonalAccessToken()));
}

@Test
public void testAuthTypePat() throws IOException {
final ConfigUtil.TestTask t = ConfigUtil.createTestTask();
testRun(
(x) -> x.set("auth_type", "pat").set("personal_access_token", t.getPersonalAccessToken()));
}

@Test
public void testAtuTypePatWithoutPersonalAccessToken() {
testConfigException((x) -> x.set("auth_type", "pat"), "personal_access_token");
}

@Test
public void testAuthTypeM2MOauth() throws IOException {
final ConfigUtil.TestTask t = ConfigUtil.createTestTask();
testRun(
(x) ->
x.set("auth_type", "oauth-m2m")
.set("oauth2_client_id", t.getOauth2ClientId())
.set("oauth2_client_secret", t.getOauth2ClientSecret()));
}

@Test
public void testAuthTypeM2MOauthWithoutOauth2ClientId() {
final ConfigUtil.TestTask t = ConfigUtil.createTestTask();
testConfigException(
(x) ->
x.set("auth_type", "oauth-m2m").set("oauth2_client_secret", t.getOauth2ClientSecret()),
"oauth2_client_id");
}

@Test
public void testAuthTypeM2MOauthWithoutOauth2ClientSecret() {
final ConfigUtil.TestTask t = ConfigUtil.createTestTask();
testConfigException(
(x) -> x.set("auth_type", "oauth-m2m").set("oauth2_client_id", t.getOauth2ClientId()),
"oauth2_client_secret");
}

@Test
public void testInvalidAuthType() {
testConfigException((x) -> x.set("auth_type", "invalid"), "auth_type");
}

private void testRun(Function<ConfigSource, ConfigSource> setConfigSource) throws IOException {
final String quotedFullTableName = ConfigUtil.createRandomQuotedFullTableName();
ConnectionUtil.run(
String.format("create table %s (_c0 LONG)", quotedFullTableName),
String.format("INSERT INTO %s VALUES (1)", quotedFullTableName));
final ConfigUtil.TestTask t = ConfigUtil.createTestTask();
ConfigSource configSource =
createMinimumConfigSource()
.set("query", String.format("select * from %s", quotedFullTableName));
Path out = embulk.createTempFile("csv");
embulk.runInput(setConfigSource.apply(configSource), out);
Assert.assertEquals("1\n", readFile(out));
}

private void testConfigException(
Function<ConfigSource, ConfigSource> setConfigSource, String containedMessage) {
ConfigSource configSource = createMinimumConfigSource();
Path out = embulk.createTempFile("csv");
Exception e =
Assert.assertThrows(
PartialExecutionException.class,
() -> embulk.runInput(setConfigSource.apply(configSource), out));
Assert.assertTrue(e.getCause() instanceof ConfigException);
Assert.assertTrue(
String.format("「%s」 does not contains '%s'", e.getMessage(), containedMessage),
e.getMessage().contains(containedMessage));
}

private ConfigSource createMinimumConfigSource() {
final ConfigUtil.TestTask t = ConfigUtil.createTestTask();
return ConfigUtil.emptyConfigSource()
.set("type", "databricks")
.set("server_hostname", t.getServerHostname())
.set("http_path", t.getHTTPPath());
}
}
13 changes: 11 additions & 2 deletions src/test/java/org/embulk/input/databricks/util/ConfigUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@ public interface TestTask extends Task {
/** Personal access token from the {@code personal_access_token} key of the test config. */
@Config("personal_access_token")
public String getPersonalAccessToken();

/** OAuth2 client id from the {@code oauth2_client_id} key of the test config. */
@Config("oauth2_client_id")
public String getOauth2ClientId();

/** OAuth2 client secret from the {@code oauth2_client_secret} key of the test config. */
@Config("oauth2_client_secret")
public String getOauth2ClientSecret();

/** Catalog name from the {@code catalog_name} key of the test config. */
@Config("catalog_name")
public String getCatalogName();

Expand Down Expand Up @@ -80,11 +86,14 @@ public static String createRandomTableName() {
return createTableName(UUID.randomUUID().toString());
}

/**
 * Creates a new config source from the shared config mapper factory, with no keys set by
 * this method.
 *
 * @return a freshly created {@link ConfigSource}
 */
public static ConfigSource emptyConfigSource() {
  final ConfigSource source = CONFIG_MAPPER_FACTORY.newConfigSource();
  return source;
}

public static ConfigSource createBasePluginConfigSource() {
final TestTask t = createTestTask();

return CONFIG_MAPPER_FACTORY
.newConfigSource()
return emptyConfigSource()
.set("type", "databricks")
.set("server_hostname", t.getServerHostname())
.set("http_path", t.getHTTPPath())
Expand Down

0 comments on commit 272bd51

Please sign in to comment.