Skip to content

Commit

Permalink
feat(endpoints)!: integrate endpoints 2.0 (#433)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ganesh Jangir authored Oct 27, 2022
1 parent 7c67c40 commit 092b9ed
Show file tree
Hide file tree
Showing 27 changed files with 220 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
//

public struct DefaultSDKRuntimeConfiguration: SDKRuntimeConfiguration {
public var endpoint: String?
public let retryer: SDKRetryer
public var clientLogMode: ClientLogMode
public var logger: LogAgent
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ public protocol SDKRuntimeConfiguration {
var logger: LogAgent {get}
var clientLogMode: ClientLogMode {get}
var retryer: SDKRetryer {get}
var endpoint: String? {get set}
}

public extension SDKRuntimeConfiguration {
Expand Down
35 changes: 34 additions & 1 deletion Packages/ClientRuntime/Sources/Networking/Endpoint.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,49 @@ public struct Endpoint: Hashable {
public let host: String
public let port: Int16

public let headers: Headers?
public let properties: [String: AnyHashable]

public init(urlString: String,
headers: Headers? = nil,
properties: [String: AnyHashable] = [:]) throws {
guard let url = URL(string: urlString) else {
throw ClientError.unknownError("invalid url \(urlString)")
}

try self.init(url: url, headers: headers, properties: properties)
}

public init(url: URL,
headers: Headers? = nil,
properties: [String: AnyHashable] = [:]) throws {
guard let host = url.host else {
throw ClientError.unknownError("invalid host \(String(describing: url.host))")
}

self.init(host: host,
path: url.path,
port: Int16(url.port ?? 443),
queryItems: url.toQueryItems(),
protocolType: ProtocolType(rawValue: url.scheme ?? ProtocolType.https.rawValue),
headers: headers,
properties: properties)
}

public init(host: String,
path: String = "/",
port: Int16 = 443,
queryItems: [URLQueryItem]? = nil,
protocolType: ProtocolType? = .https) {
protocolType: ProtocolType? = .https,
headers: Headers? = nil,
properties: [String: AnyHashable] = [:]) {
self.host = host
self.path = path
self.port = port
self.queryItems = queryItems
self.protocolType = protocolType
self.headers = headers
self.properties = properties
}
}

Expand Down
32 changes: 30 additions & 2 deletions Packages/ClientRuntime/Sources/Networking/Http/Headers.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import AwsCommonRuntimeKit

public struct Headers: Equatable {
public struct Headers: Hashable {
public var headers: [Header] = []

/// Creates an empty instance.
Expand Down Expand Up @@ -161,6 +161,17 @@ public struct Headers: Equatable {
}
}

extension Headers: Equatable {
/// Returns a boolean value indicating whether two values are equal irrespective of order.
/// - Parameters:
/// - lhs: The first `Headers` to compare.
/// - rhs: The second `Headers` to compare.
/// - Returns: `true` if the two values are equal irrespective of order, otherwise `false`.
public static func == (lhs: Headers, rhs: Headers) -> Bool {
return lhs.headers.sorted() == rhs.headers.sorted()
}
}

extension Array where Element == Header {
/// Case-insensitively finds the index of an `Header` with the provided name, if it exists.
func index(of name: String) -> Int? {
Expand All @@ -175,7 +186,7 @@ extension Array where Element == Header {
}
}

public struct Header: Equatable {
public struct Header: Hashable {
public var name: String
public var value: [String]

Expand All @@ -190,6 +201,23 @@ public struct Header: Equatable {
}
}

extension Header: Equatable {
public static func == (lhs: Header, rhs: Header) -> Bool {
return lhs.name == rhs.name && lhs.value.sorted() == rhs.value.sorted()
}
}

extension Header: Comparable {
/// Compares two `Header` instances by name.
/// - Parameters:
/// - lhs: The first `Header` to compare.
/// - rhs: The second `Header` to compare.
/// - Returns: `true` if the first `Header`'s name is less than the second `Header`'s name, otherwise `false`.
public static func < (lhs: Header, rhs: Header) -> Bool {
return lhs.name < rhs.name
}
}

extension Headers {
func toHttpHeaders() -> HttpHeaders {
let httpHeaders = HttpHeaders()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,14 @@

import struct Foundation.URLQueryItem
public typealias URLQueryItem = Foundation.URLQueryItem

extension URLQueryItem: Comparable {
/// Compares two `URLQueryItem` instances by their `name` property.
/// - Parameters:
/// - lhs: The first `URLQueryItem` to compare.
/// - rhs: The second `URLQueryItem` to compare.
/// - Returns: `true` if the `name` property of `lhs` is less than the `name` property of `rhs`.
public static func < (lhs: URLQueryItem, rhs: URLQueryItem) -> Bool {
lhs.name < rhs.name
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0.

import XCTest
import AwsCommonRuntimeKit
import AwsCCommon

open class CrtXCBaseTestCase: XCTestCase {
let allocator = TracingAllocator(tracingStacksOf: defaultAllocator)
let logging = Logger(pipe: stdout, level: .trace, allocator: defaultAllocator)

open override func setUp() {
super.setUp()

AwsCommonRuntimeKit.initialize(allocator: self.allocator)
}

open override func tearDown() {
AwsCommonRuntimeKit.cleanUp()

allocator.dump()
XCTAssertEqual(allocator.count, 0,
"Memory was leaked: \(allocator.bytes) bytes in \(allocator.count) allocations")

super.tearDown()
}
}
2 changes: 1 addition & 1 deletion gradle.properties
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ kotlin.code.style=official
# config

# codegen
smithyVersion=[1.25.0,1.26.0[
smithyVersion=[1.26.0,1.27.0[
smithyGradleVersion=0.6.0

# kotlin
Expand Down
1 change: 1 addition & 0 deletions settings.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,4 @@ pluginManagement {
rootProject.name = "smithy-swift"

include("smithy-swift-codegen")
include("smithy-swift-codegen-test")
2 changes: 1 addition & 1 deletion smithy-swift-codegen-test/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ repositories {
task<Exec>("buildGeneratedSDK") {
val pathToGeneratedSDK = "$buildDir/smithyprojections/smithy-swift-codegen-test/source/swift-codegen"
workingDir(".")
commandLine("which", "swift")
commandLine("swift", "--version")
isIgnoreExitValue=true
dependsOn(tasks["build"])

Expand Down
1 change: 1 addition & 0 deletions smithy-swift-codegen/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ dependencies {
implementation("software.amazon.smithy:smithy-aws-traits:$smithyVersion")
testImplementation("org.junit.jupiter:junit-jupiter:$junitVersion")
testImplementation("io.kotest:kotest-assertions-core-jvm:$kotestVersion")
implementation("software.amazon.smithy:smithy-rules-engine:$smithyVersion")
}

jacoco {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ object ClientRuntimeTypes {
}

object Core {
val Endpoint = runtimeSymbol("Endpoint")
val ByteStream = runtimeSymbol("ByteStream")
val Date = runtimeSymbol("Date")
val Data = runtimeSymbol("Data")
Expand All @@ -96,6 +97,12 @@ object ClientRuntimeTypes {
val PaginateToken = runtimeSymbol("PaginateToken")
val PaginatorSequence = runtimeSymbol("PaginatorSequence")
}

object Test {
val CrtXCBaseTestCase = buildSymbol {
name = "CrtXCBaseTestCase"
}
}
}

private fun runtimeSymbol(name: String): Symbol = buildSymbol {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class CodegenVisitor(context: PluginContext) : ShapeVisitor.Default<Void>() {
private val integrations: List<SwiftIntegration>
private val protocolGenerator: ProtocolGenerator?
private val baseGenerationContext: GenerationContext
private val protocolContext: ProtocolGenerator.GenerationContext?

init {
LOGGER.info("Attempting to discover SwiftIntegration from classpath...")
Expand Down Expand Up @@ -78,6 +79,8 @@ class CodegenVisitor(context: PluginContext) : ShapeVisitor.Default<Void>() {
}

baseGenerationContext = GenerationContext(model, symbolProvider, settings, protocolGenerator, integrations)

protocolContext = protocolGenerator?.let { ProtocolGenerator.GenerationContext(settings, model, service, symbolProvider, integrations, it.protocol, writers) }
}

fun preprocessModel(model: Model): Model {
Expand Down Expand Up @@ -115,31 +118,25 @@ class CodegenVisitor(context: PluginContext) : ShapeVisitor.Default<Void>() {
serviceShapes.forEach { it.accept(this) }
var shouldGenerateTestTarget = false
protocolGenerator?.apply {
val ctx = ProtocolGenerator.GenerationContext(
settings,
model,
service,
symbolProvider,
integrations,
protocolGenerator.protocol,
writers
)
LOGGER.info("[${service.id}] Generating serde for protocol ${protocolGenerator.protocol}")
generateSerializers(ctx)
generateDeserializers(ctx)
generateCodableConformanceForNestedTypes(ctx)

initializeMiddleware(ctx)

LOGGER.info("[${service.id}] Generating unit tests for protocol ${protocolGenerator.protocol}")
val numProtocolUnitTestsGenerated = generateProtocolUnitTests(ctx)
shouldGenerateTestTarget = (numProtocolUnitTestsGenerated > 0)

LOGGER.info("[${service.id}] Generating service client for protocol ${protocolGenerator.protocol}")

generateProtocolClient(ctx)
protocolContext?.let { ctx ->
LOGGER.info("[${service.id}] Generating serde for protocol ${protocolGenerator.protocol}")
generateSerializers(ctx)
generateDeserializers(ctx)
generateCodableConformanceForNestedTypes(ctx)

initializeMiddleware(ctx)

LOGGER.info("[${service.id}] Generating unit tests for protocol ${protocolGenerator.protocol}")
val numProtocolUnitTestsGenerated = generateProtocolUnitTests(ctx)
shouldGenerateTestTarget = (numProtocolUnitTestsGenerated > 0)

LOGGER.info("[${service.id}] Generating service client for protocol ${protocolGenerator.protocol}")

generateProtocolClient(ctx)

integrations.forEach { it.writeAdditionalFiles(baseGenerationContext, ctx, writers) }
}
}
integrations.forEach { it.writeAdditionalFiles(baseGenerationContext, writers) }

println("Flushing swift writers")
val dependencies = writers.dependencies
Expand Down Expand Up @@ -176,9 +173,8 @@ class CodegenVisitor(context: PluginContext) : ShapeVisitor.Default<Void>() {
}

override fun serviceShape(shape: ServiceShape): Void? {
writers.useShapeWriter(shape) {
writer: SwiftWriter ->
ServiceGenerator(settings, model, symbolProvider, writer, writers, protocolGenerator).render()
writers.useShapeWriter(shape) { writer: SwiftWriter ->
ServiceGenerator(settings, model, symbolProvider, writer, writers, protocolGenerator, protocolContext).render()
ServiceNamespaceGenerator(settings, model, symbolProvider, writer).render()
}
return null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package software.amazon.smithy.swift.codegen
import software.amazon.smithy.codegen.core.Symbol

abstract class Middleware(private val writer: SwiftWriter, shapeSymbol: Symbol, step: OperationStep) {
open val id: String get() = typeName
open val typeName: String = "${shapeSymbol.name}Middleware"

open val inputType: Symbol = step.inputType
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class MiddlewareGenerator(
fun generate() {

writer.openBlock("public struct ${middleware.typeName}: ${middleware.getTypeInheritance()} {", "}") {
writer.write("public let id: \$N = \"${middleware.typeName}\"", SwiftTypes.String)
writer.write("public let id: \$N = \"${middleware.id}\"", SwiftTypes.String)
writer.write("")
middleware.properties.forEach {
val memberName = it.key
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.traits.PaginatedTrait
import software.amazon.smithy.swift.codegen.core.CodegenContext
import software.amazon.smithy.swift.codegen.integration.ProtocolGenerator
import software.amazon.smithy.swift.codegen.integration.SwiftIntegration
import software.amazon.smithy.swift.codegen.model.SymbolProperty
import software.amazon.smithy.swift.codegen.model.camelCaseName
Expand All @@ -29,7 +30,7 @@ class PaginatorGenerator : SwiftIntegration {
override fun enabledForService(model: Model, settings: SwiftSettings): Boolean =
model.operationShapes.any { it.hasTrait<PaginatedTrait>() }

override fun writeAdditionalFiles(ctx: CodegenContext, delegator: SwiftDelegator) {
override fun writeAdditionalFiles(ctx: CodegenContext, protoCtx: ProtocolGenerator.GenerationContext, delegator: SwiftDelegator) {
val service = ctx.model.expectShape<ServiceShape>(ctx.settings.service)
val paginatedIndex = PaginatedIndex.of(ctx.model)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ import software.amazon.smithy.model.knowledge.TopDownIndex
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.traits.StreamingTrait
import software.amazon.smithy.swift.codegen.integration.DefaultServiceConfig
import software.amazon.smithy.swift.codegen.integration.ProtocolGenerator
import software.amazon.smithy.swift.codegen.integration.SectionId
import software.amazon.smithy.swift.codegen.integration.middlewares.handlers.MiddlewareShapeUtils
import software.amazon.smithy.swift.codegen.model.camelCaseName

Expand All @@ -27,12 +29,16 @@ class ServiceGenerator(
private val symbolProvider: SymbolProvider,
private val writer: SwiftWriter,
private val delegator: SwiftDelegator,
private val protocolGenerator: ProtocolGenerator? = null
private val protocolGenerator: ProtocolGenerator? = null,
private val protocolGenerationContext: ProtocolGenerator.GenerationContext?
) {
private var service = settings.getService(model)
private val serviceSymbol: Symbol by lazy {
symbolProvider.toSymbol(service)
}
private val serviceConfig: DefaultServiceConfig by lazy {
DefaultServiceConfig(writer, serviceSymbol.name)
}
private val rootNamespace = settings.moduleName

companion object {
Expand Down Expand Up @@ -107,7 +113,7 @@ class ServiceGenerator(
* We will generate the following:
* ```
* public protocol ExampleServiceProtocol {
* func getFoo(input: GetFooInput, completion: @escaping (SdkResult<GetFooOutput, GetFooError>) -> Void)
* func getFoo(input: GetFooInput) async throws -> GetFooResponse
* }
* ```
*/
Expand All @@ -127,8 +133,20 @@ class ServiceGenerator(
}
.closeBlock("}")
.write("")

val sectionContext = mapOf(
"serviceSymbol" to serviceSymbol,
"protocolGenerator" to protocolGenerator,
"protocolGenerationContext" to protocolGenerationContext
)
writer.declareSection(ConfigurationProtocolSectionId, sectionContext) {
writer.openBlock("public protocol \$L : \$L {", "}", serviceConfig.typeProtocol, serviceConfig.getTypeInheritance()) {
}
}.write("")
}

object ConfigurationProtocolSectionId : SectionId

/*
Renders the Operation Error enum
*/
Expand Down
Loading

0 comments on commit 092b9ed

Please sign in to comment.