CombinatorialUtil.java

/*
 * #%L
 * wcm.io
 * %%
 * Copyright (C) 2020 wcm.io
 * %%
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * #L%
 */
package io.wcm.qa.glnm.junit.combinatorial;

import static com.google.common.collect.Lists.transform;
import static io.wcm.qa.glnm.junit.combinatorial.ReAnnotationUtils.findRepeatableAnnotations;
import static java.util.stream.Collectors.toList;
import static org.junit.platform.commons.support.AnnotationSupport.findAnnotation;

import java.lang.annotation.Annotation;
import java.lang.reflect.Method;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Stream;

import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.ObjectUtils;
import org.apache.commons.lang3.reflect.TypeUtils;
import org.junit.jupiter.api.extension.Extension;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.params.CombinatorialTestMethodContext;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.ArgumentsProvider;
import org.junit.jupiter.params.provider.ArgumentsSource;
import org.junit.jupiter.params.support.AnnotationConsumer;
import org.junit.platform.commons.util.ExceptionUtils;
import org.junit.platform.commons.util.ReflectionUtils;

import io.wcm.qa.glnm.exceptions.GaleniumException;

final class CombinatorialUtil {

  private CombinatorialUtil() {
    // do not instantiate
  }

  /**
   * <p>
   * extractArgumentProviders.
   * </p>
   *
   * @param context a {@link org.junit.jupiter.api.extension.ExtensionContext} object.
   * @return a list of {@link ArgumentsProvider ArgumentsProviders}
   * @since 5.0.0
   */
  public static List<ArgumentsProvider> extractArgumentProviders(ExtensionContext context) {
    return extractFromAnnotations(context, ArgumentsSource.class, t -> s -> providerFromSource(t, s));
  }

  private static <T, A extends Annotation> List<T> collectFromAnnotations(
      Class<A> annotationType,
      Annotation[] declaredAnnotations,
      Function<Annotation, Function<A, T>> mappingProducer) {

    List<T> result = new ArrayList<T>();

    for (Annotation annotation : declaredAnnotations) {
      List<T> collectedFromAnnotation = findRepeatableAnnotations(annotation, annotationType)
          .stream()
          .map(mappingProducer.apply(annotation))
          .filter(ObjectUtils::allNotNull)
          .collect(toList());
      result.addAll(collectedFromAnnotation);
    }

    return result;
  }

  @SuppressWarnings("unchecked")
  private static CombinableProvider extensionFromSource(Annotation original, CombinableSource source) {
    CombinableProvider providerInstance = ReflectionUtils.newInstance(source.value());
    if (providerInstance instanceof AnnotationConsumer<?>) {
      feedAnnotationConsumer(original, (AnnotationConsumer)providerInstance);
    }
    if (providerInstance.providedType().isAssignableFrom(Extension.class)) {
      return providerInstance;
    }
    return null;
  }

  private static <T, A extends Annotation> List<T> extractFromAnnotations(
      ExtensionContext extensionContext,
      Class<A> annotationType,
      Function<Annotation, Function<A, T>> mappingProducer) {
    Method testMethod = extensionContext.getRequiredTestMethod();
    Annotation[] declaredAnnotations = testMethod.getDeclaredAnnotations();
    return collectFromAnnotations(annotationType, declaredAnnotations, mappingProducer);
  }

  @SuppressWarnings("unchecked")
  private static void feedAnnotationConsumer(Annotation annotation, AnnotationConsumer consumer) {
    Class<? extends Annotation> consumedType = getConsumedAnnotationType(consumer);
    Class<? extends Annotation> originalType = annotation.annotationType();
    if (consumedType.isAssignableFrom(originalType)) {
      consumer.accept(annotation);
    }
    else {
      Optional<? extends Annotation> optionalAnnotation = findAnnotation(originalType, consumedType);
      consumer.accept(optionalAnnotation.orElseThrow(new Supplier<GaleniumException>() {

        @Override
        public GaleniumException get() {
          return new GaleniumException("No annotation found to consume for : " + consumer);
        }
      }));
    }
  }

  private static Class<? extends Annotation> getConsumedAnnotationType(AnnotationConsumer annotationConsumer) {
    Class<? extends AnnotationConsumer> consumerClass = annotationConsumer.getClass();
    Type[] interfaces = consumerClass.getGenericInterfaces();
    for (Type type : interfaces) {
      if (type instanceof ParameterizedType
          && TypeUtils.isAssignable(type, AnnotationConsumer.class)) {
        ParameterizedType consumerType = (ParameterizedType)type;
        Type consumedAnnotationType = consumerType.getActualTypeArguments()[0];
        return rawType(consumedAnnotationType);
      }
    }
    throw new GaleniumException("Did not find type of consumed annotation: " + annotationConsumer);
  }

  private static List<List<Combinable>> providersToArguments(
      List<ArgumentsProvider> providers,
      ExtensionContext context) {
    return transform(
        providers,
        p -> arguments(p, context));
  }

  @SuppressWarnings("unchecked")
  private static Class<? extends Annotation> rawType(Type argumentType) {
    return (Class<? extends Annotation>)TypeUtils.getRawType(argumentType, Annotation.class);
  }

  static List<Combinable> arguments(
      ArgumentsProvider provider,
      ExtensionContext context) {
    try {
      return provider.provideArguments(context)
          .map(Combinable::new)
          .collect(toList());
    }
    catch (Exception e) {
      throw ExceptionUtils.throwAsUncheckedException(e);
    }
  }

  static Object[] consumedArguments(
      Object[] arguments,
      CombinatorialTestMethodContext methodContext) {
    int parameterCount = methodContext.getParameterCount();
    return methodContext.hasAggregator() ? arguments
        : (arguments.length > parameterCount ? Arrays.copyOf(arguments, parameterCount) : arguments);
  }

  static List<List<Combinable>> extractArguments(ExtensionContext context) {
    return providersToArguments(extractArgumentProviders(context), context);
  }

  static List<CombinableProvider> extractExtensionSources(ExtensionContext context) {
    return extractFromAnnotations(context, CombinableSource.class, t -> s -> extensionFromSource(t, s));
  }

  static <T> List<T> filter(Class<T> type, List<Combinable> values) {
    return values.stream()
        .map(Combinable::getValue)
        .filter(v -> type.isInstance(v))
        .map(v -> type.cast(v))
        .collect(toList());
  }

  static Arguments flattenArgumentsList(List<Arguments> args) {
    return Arguments.of(listToArray(args));
  }

  static Object[] listToArray(List<Arguments> list) {
    Object[] listAsSingleArray = new Object[] {};
    for (Arguments args : list) {
      listAsSingleArray = ArrayUtils.addAll(listAsSingleArray, args.get());
    }
    return listAsSingleArray;
  }

  static ArgumentsProvider providerFromSource(Annotation original, ArgumentsSource source) {
    Class<? extends ArgumentsProvider> providerType = source.value();
    ArgumentsProvider providerInstance = ReflectionUtils.newInstance(providerType);
    if (providerInstance instanceof AnnotationConsumer<?>) {
      feedAnnotationConsumer(original, (AnnotationConsumer)providerInstance);
    }
    return providerInstance;
  }


}