package org.jboss.arquillian.testenricher.ejb;
import java.lang.annotation.Annotation;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.List;
import javax.ejb.EJB;
import javax.naming.InitialContext;
import javax.naming.NamingException;
import org.jboss.arquillian.spi.TestEnricher;
{
// Fully qualified class name of the EJB annotation. Used to test whether the annotation
// class is available and to load it reflectively before any injection is attempted.
private static final String ANNOTATION_NAME = "javax.ejb.EJB";
// Attribute name "beanInterface". NOTE(review): not referenced in the visible part of this class.
private static final String ANNOTATION_FIELD_BEAN_INTERFACE = "beanInterface";
// Attribute name "mappedName". NOTE(review): not referenced in the visible part of this class.
private static final String ANNOTATION_FIELD_MAPPED_NAME = "mappedName";
/**
 * Injects EJB references into the given test case instance.
 * <p>
 * Nothing happens when the class named by {@code ANNOTATION_NAME} is not present,
 * as reported by {@code SecurityActions.isClassPresent}.
 *
 * @param testCase the test instance to enrich
 */
public void enrich(Object testCase)
{
   final boolean ejbAnnotationAvailable = SecurityActions.isClassPresent(ANNOTATION_NAME);
   if (!ejbAnnotationAvailable)
   {
      // The EJB annotation class is missing, so there is nothing to inject.
      return;
   }
   injectClass(testCase);
}
/**
 * Produces the argument array for the given method.
 * <p>
 * This enricher does not supply any argument values: the returned array has one slot
 * per declared parameter and every slot is {@code null}.
 *
 * @param method the method whose parameters are to be resolved
 * @return a new array sized to the method's parameter count, filled with {@code null}
 */
public Object[] resolve(Method method)
{
   final int parameterCount = method.getParameterTypes().length;
   return new Object[parameterCount];
}
throws IllegalArgumentException
{
// Both arguments are mandatory; fail fast with a message naming the missing one.
if (clazz == null)
{
throw new IllegalArgumentException("clazz must be specified");
}
if (annotation == null)
{
throw new IllegalArgumentException("annotation must be specified");
}
// The actual field scan is delegated to SecurityActions.
return SecurityActions.getFieldsWithAnnotation(clazz, annotation);
}
{
// Injects EJB references into annotated fields first, then through annotated setter methods.
// Any exception raised along the way is rethrown as a RuntimeException (see the catch below).
try
{
// Load the annotation type by name through the thread context class loader.
@SuppressWarnings("unchecked")
Class<? extends Annotation> ejbAnnotation = (Class<? extends Annotation>)SecurityActions.getThreadContextClassLoader().loadClass(ANNOTATION_NAME);
List<Field> annotatedFields = SecurityActions.getFieldsWithAnnotation(
testCase.getClass(),
ejbAnnotation);
for(Field field : annotatedFields)
{
// Only fields that are currently null are populated; existing values are left alone.
if(field.get(testCase) == null)
{
EJB fieldAnnotation = (EJB) field.getAnnotation(ejbAnnotation);
Object ejb = lookupEJB(field.getType(), fieldAnnotation.mappedName());
field.set(testCase, ejb);
}
}
List<Method> methods = SecurityActions.getMethodsWithAnnotation(
testCase.getClass(),
ejbAnnotation);
for(Method method : methods)
{
// Annotated methods must look like setters: exactly one parameter and a name starting with "set".
// These RuntimeExceptions are thrown inside the try block, so they are caught below and
// reach the caller wrapped in "Could not inject members".
if(method.getParameterTypes().length != 1)
{
throw new RuntimeException("@EJB only allowed on single argument methods");
}
if(!method.getName().startsWith("set"))
{
throw new RuntimeException("@EJB only allowed on 'set' methods");
}
// Search the annotations of the first (only) parameter for an EJB annotation.
// NOTE(review): the methods appear to be selected because the method itself carries the
// annotation, yet mappedName is read from the parameter's annotations. Confirm whether
// the method-level annotation was intended here; as written, mappedName is null unless
// the parameter is annotated.
EJB parameterAnnotation = null;
for (Annotation annotation : method.getParameterAnnotations()[0])
{
if (EJB.class.isAssignableFrom(annotation.annotationType()))
{
parameterAnnotation = (EJB) annotation;
}
}
// A null mappedName makes lookupEJB fall back to its default JNDI name candidates.
String mappedName = parameterAnnotation == null ? null : parameterAnnotation.mappedName();
Object ejb = lookupEJB(method.getParameterTypes()[0], mappedName);
method.invoke(testCase, ejb);
}
}
catch (Exception e)
{
// Wrap every failure (reflection, lookup, validation) while keeping the original as the cause.
throw new RuntimeException("Could not inject members", e);
}
}
/**
 * Looks up an EJB in JNDI for an injection point of the given type.
 * <p>
 * When {@code mappedName} is non-null and non-empty it is the only name tried.
 * Otherwise a fixed list of candidate names derived from the simple name of
 * {@code fieldType} is tried in order, and the first successful lookup wins.
 *
 * @param fieldType  the type of the field or setter parameter being injected
 * @param mappedName explicit JNDI name to use, or {@code null}/empty to use the default candidates
 * @return the first object successfully looked up
 * @throws Exception a {@link NamingException} listing every attempted name when no lookup
 *                   succeeds; the failure of the last attempt is attached as its root cause
 */
protected Object lookupEJB(Class<?> fieldType, String mappedName) throws Exception
{
   InitialContext initcontext = createContext();

   final String[] jndiNames;
   if ((mappedName != null) && (mappedName.length() > 0))
   {
      // An explicit mapped name replaces the default candidates entirely.
      jndiNames = new String[] { mappedName };
   }
   else
   {
      final String simpleName = fieldType.getSimpleName();
      jndiNames = new String[] {
            "java:global/test.ear/test/" + simpleName + "Bean",
            "java:global/test.ear/test/" + simpleName,
            "java:global/test/" + simpleName,
            "java:global/test/" + simpleName + "Bean",
            "java:global/test/" + simpleName + "/no-interface",
            "test/" + simpleName + "Bean/local",
            "test/" + simpleName + "Bean/remote",
            "test/" + simpleName + "/no-interface",
            simpleName + "Bean/local",
            simpleName + "Bean/remote",
            simpleName + "/no-interface"
      };
   }

   NamingException lastFailure = null;
   for (String jndiName : jndiNames)
   {
      try
      {
         return initcontext.lookup(jndiName);
      }
      catch (NamingException e)
      {
         // A miss is expected while probing candidates; remember it so the final
         // error can report why the last attempt failed instead of discarding it.
         lastFailure = e;
      }
   }

   NamingException notFound = new NamingException(
         "No EJB found in JNDI, tried the following names: " + joinJndiNames(jndiNames));
   if (lastFailure != null)
   {
      notFound.setRootCause(lastFailure);
   }
   throw notFound;
}
{
// Builds a new InitialContext using its no-argument constructor.
return new InitialContext();
}
{
// Concatenates the given strings into one string for error reporting.
// ", " is appended after every element, including the last one, so a non-empty
// result always ends with a trailing ", ".
StringBuilder sb = new StringBuilder();
for(String string: strings)
{
sb.append(string).append(", ");
}
return sb.toString();
}
}