背景
每次在工作中使用 spring,总会被其流畅代码思路和迷人的架构模式以及设计模式吸引,也看过一阵子源码,但总觉得了解的是是而非,所以想着何不自己手写一个 springIoc 用来梳理巩固自己所学的知识,以及自我充电
Spring 核心
- **控制反转 (**IoC,Inversion of Control)(本篇文章实现)
传统的 Java 开发模式中,当需要一个对象时,我们使用 new 或者通过 getInstance 等直接或者间接调用构造方法创建一个对象,而在 Spring 开发模式中,Spring 容器使用工厂模式为我们创建了所需要的对象,不需要我们自己去创建了,直接调用 Spring 提供的对象就可以了,这就是控制反转,相信我们在用 Spring 的时候,用 XML 或者注解了解过
- 面向切面编程(AOP)(后续实现)
在面向对象编程(OOP)中,我们将事务纵向抽成一个个的对象,而在面向切面编程中,我们将一个个的对象某些类似的方面横向抽成一个切面,对这个切面进行一些如权限控制,事务管理,日志记录等公用操作处理的过程,就是面向切面编程的思想。
面向切面编程也是 Spring 非常具有特色的功能,在实际工作中也非常广泛应用,就比如之前在公司我用 springAop 和 spring SPEL 机制实现的一部分功能
前置环境准备
假如我是一个 Spring 开发人员,我要实现一个 IoC,我需要怎么做呢?
- 需要一个解析获取用户定义包扫描路径 比如
component-scan
注解扫描路径,可以放在 XML 或者文件等等,这里放在applicationContext.xml
- 获取所有 java.class 文件,Java 编译好的文件都放在 xxx/target/classes 下,如果是测试包则放在 xxx/target/test-classes 下,所以我们需要递归获取这下面的文件
- 解析处理文件全路径名称,根据系统环境区分
/
和'',拼接类路径 - 根据类路径,进行反射构造实例 class.forName(xxxx)
- 根据自定义注解标识,哪些类需要构造实例
- 构造实例后需要填充属性,注解标识属性实例构造填充到该实例
- 创建一个存取构造好的 bean 的容器
搭建项目
maven 依赖
创建一个 maven 项目,因为需要解析 XML 获取 component-scan 导入 dom4j依赖
和 lombok依赖
以及相关 log 日志依赖
<dependencies>
<dependency>
<groupId>org.apache.logging.log4j</groupId>
<artifactId>log4j-api</artifactId>
<version>2.17.1</version>
</dependency>
<!-- Log4j Core -->
<dependency>
<groupId>org.apache.logging.log4j</groupId>
<artifactId>log4j-core</artifactId>
<version>2.17.1</version>
</dependency>
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<version>1.18.30</version>
</dependency>
<!-- https://mvnrepository.com/artifact/org.dom4j/dom4j -->
<dependency>
<groupId>org.dom4j</groupId>
<artifactId>dom4j</artifactId>
<version>2.1.4</version>
</dependency>
</dependencies>
包结构
- com.xiaohu.springioc 下面的包业务代码,写控制层,service 层代码
- org.springframework.xx 下面的包为自己简易的写 Spring 相关 IoC
编码
实现 bean 容器
package org.springframework.container;
import lombok.SneakyThrows;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.springframework.stereotype.*;
import org.springframework.xml.XmlParser;
import java.io.File;
import java.lang.annotation.Annotation;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
/**
* @Version 1.0
* @Author xiaohugg
* @Description ClassPathXmlApplicationContext
* @Date 2023/11/29 11:27
**/
public class ClassPathXmlApplicationContext {
private static final Logger logger = LogManager.getLogger(ClassPathXmlApplicationContext.class);
/**
* spring的ioc名字作为key
*/
private final Map<String, Object> iocNameContainer = new ConcurrentHashMap<>();
/**
* spring的class作为key
*/
private final Map<Class<?>, Object> iocClassContainer = new ConcurrentHashMap<>();
/**
* 根据接口,获取接口下的实现类
* 类似 context.getBean(UserService.class)
*/
private final Map<Class<?>, List<Object>> iocInterfacesContainer = new ConcurrentHashMap<>();
private final Set<String> classFiles = new HashSet<>();
private final String xmlPath;
public ClassPathXmlApplicationContext(String xmlPath) {
this.xmlPath = xmlPath;
refresh();
}
@SneakyThrows
private void refresh() {
//解析componentScanPath 包扫描路径
String componentScanPath = XmlParser.parse(xmlPath);
//获取包扫描路径的class文件路径
File file = findClassPath(componentScanPath);
//获取.class文件结尾的包全路径名
findClassFiles(file, componentScanPath, classFiles);
//反射
newInstance(classFiles);
//实现对象的属性的依赖注入
doDI();
logger.fatal("iocNameContainer {}", iocNameContainer);
logger.fatal("iocClassContainer {}", iocClassContainer);
logger.fatal("iocInterfacesContainer {}", iocInterfacesContainer);
}
private void doDI() {
Set<Map.Entry<Class<?>, Object>> entries = iocClassContainer.entrySet();
entries.forEach(it -> {
Class<?> aClass = it.getKey();
Field[] declaredFields = aClass.getDeclaredFields();
Set<Field> hasAutowiredField = Arrays.stream(declaredFields).filter(field -> field.isAnnotationPresent(Autowired.class)).collect(Collectors.toSet());
hasAutowiredField.forEach(field -> {
//依赖注入属性
Autowired annotation = field.getAnnotation(Autowired.class);
String value = annotation.value();
Object bean;
if ("".equals(value)) {
//默认按类型获取
Class<?> type = field.getType();
bean = getBean(type);
if (Objects.isNull(bean)) {
throw new IllegalStateException("获取不到 bean: " + type.getName());
}
} else {
//按用户填写的beanName获取
bean = iocNameContainer.getOrDefault(value, new IllegalArgumentException("找不到beanName: " + value));
}
try {
field.setAccessible(true);
field.set(iocClassContainer.get(aClass), bean);
} catch (IllegalAccessException e) {
logger.error("属性注入失败 {}", e.getMessage());
}
});
});
}
private static File findClassPath(String componentScanPath) {
String path = Objects.requireNonNull(Thread.currentThread().getContextClassLoader().getResource("")).getPath();
String url = path + componentScanPath.replace(".", File.separator);
// windows环境去除路径前面的 '/'
if (System.getProperty("os.name").toLowerCase().contains("win")) {
url = url.replaceFirst("/", "");
}
if (url.contains("test-classes")) {
url = url.replace("test-classes", "classes");
}
return new File(url);
}
public static String getBeanName(Class<?> c) {
try {
Annotation annotation = c.getAnnotations()[0];
Method valueMethod = annotation.annotationType().getDeclaredMethod("value");
String value = (String) valueMethod.invoke(annotation);
if (value != null && !value.isEmpty()) {
return value;
}
} catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) {
// 处理异常: 可能是注解没有value()方法,或者其他反射调用错误
logger.error("获取beanName 失败 {}", e.getMessage());
}
//没指定beanName 默认用类型首字母小写
return Character.toLowerCase(c.getSimpleName().charAt(0)) + c.getSimpleName().substring(1);
}
private void putIoc(Class<?>[] interfaces, Object instance, String beanName, Class<?> c) {
for (Class<?> anInterface : interfaces) {
iocInterfacesContainer.computeIfAbsent(anInterface, k -> new ArrayList<>()).add(instance);
}
iocNameContainer.compute(beanName, (key, value) -> {
if (value != null) {
throw new IllegalStateException("Bean with name '" + beanName + "' already exists.");
}
return instance;
});
iocClassContainer.compute(c, (key, value) -> {
if (value != null) {
throw new IllegalStateException("Bean with class name '" + c.getSimpleName() + "' already exists.");
}
return instance;
});
}
public Object getBean(String beanName) {
return iocNameContainer.getOrDefault(beanName, null);
}
public <T> T getBean(Class<T> clazz) {
//首先根据class获取,获取不到再通过接口获取
if (iocClassContainer.containsKey(clazz)) {
return clazz.cast(iocClassContainer.get(clazz));
}
List<Object> computed = iocInterfacesContainer.compute(clazz, (key, value) -> {
if (value == null || value.isEmpty()) {
return null;
}
if (value.size() > 1) {
throw new IllegalArgumentException("只能获取到一个bean 但是获取到了 " + value.size() + "个相同类型的bean");
}
return value;
});
return computed == null ? null : clazz.cast(computed.get(0));
}
private void newInstance(Set<String> classFiles) throws ClassNotFoundException, InstantiationException, IllegalAccessException {
for (String classFile : classFiles) {
try {
classFile = classFile.replace(File.separator, ".").replace(".class", "");
Class<?> c = Class.forName(classFile);
Annotation[] annotations = new Annotation[]{c.getAnnotation(Component.class), c.getAnnotation(Controller.class),
c.getAnnotation(Service.class), c.getAnnotation(Repository.class)};
if (Arrays.stream(annotations).anyMatch(Objects::nonNull)) {
String beanName = getBeanName(c);
Object instance = c.newInstance();
Class<?>[] interfaces = c.getInterfaces();
putIoc(interfaces, instance, beanName, c);
}
} catch (Exception e) {
logger.error("构造bean失败 失败原因 {}", e.getMessage());
throw e;
}
}
}
private void findClassFiles(File classFiles, String componentScanPath, Set<String> classNameList) {
File[] files = classFiles.listFiles();
if (files != null) {
for (File file : files) {
if (file.isFile() && file.getName().endsWith(".class")) {
// 如果是.class文件,添加到列表
String fullPath = file.getAbsolutePath();
int index = fullPath.indexOf(componentScanPath.replace(".", File.separator));
if (index != -1) {
String filePath = fullPath.substring(index);
classNameList.add(filePath);
}
} else if (file.isDirectory()) {
// 如果是目录,递归调用
findClassFiles(file, componentScanPath, classNameList);
}
}
}
}
}
注解标识
package org.springframework.stereotype;
import java.lang.annotation.*;
@Target({ElementType.FIELD})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface Autowired {
String value() default "";
}
package org.springframework.stereotype;
import java.lang.annotation.*;
@Target({ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface Controller {
String value() default "";
}
package org.springframework.stereotype;
import java.lang.annotation.*;
/**
* @Version 1.0
* @Author xiaohugg
* @Description Repository
* @Date 2023/11/30 11:50
**/
@Target({ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface Repository {
String value() default "";
}
package org.springframework.stereotype;
import java.lang.annotation.*;
/**
* @Version 1.0
* @Author xiaohugg
* @Description Service
* @Date 2023/11/29 11:14
**/
@Target({ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface Service {
String value() default "";
}
!!! 每个注解都有一个 value 的方法,代表 beanName 如果用户填了,则根据用户填写的 value 获取 bean,否则则根据类的名称获取,首字母小写
解析 XML 的 parse
package org.springframework.xml;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.dom4j.Attribute;
import org.dom4j.Document;
import org.dom4j.Element;
import org.dom4j.io.SAXReader;
import java.io.InputStream;
/**
* @Version 1.0
* @Author xiaohugg
* @Description XmlParser
* @Date 2023/11/29 11:47
**/
public class XmlParser {
private static final Logger logger = LogManager.getLogger(XmlParser.class);
private XmlParser() {
}
public static String parse(String path) {
try (InputStream inputStream = Thread.currentThread().getContextClassLoader().getResourceAsStream(path)) {
SAXReader saxReader = SAXReader.createDefault();
Document document = saxReader.read(inputStream);
Element rootElement = document.getRootElement();
Element element = rootElement.element("component-scan");
Attribute basePackage = element.attribute("base-package");
return basePackage.getText();
} catch (Exception e) {
logger.error("解析xml失败 {}",e.getMessage());
throw new IllegalArgumentException("解析错误");
}
}
}
<span class="ne-text">applicationContext.xml</span>
填写注解扫描包路径
<?xml version="1.0" encoding="UTF-8" ?>
<beans>
<!--扫描包-->
<component-scan base-package="com.xiaohu"/>
</beans>
测试
package com.xiaohu.springioc;
import com.xiaohu.springioc.controller.EmployeesController;
import org.springframework.container.ClassPathXmlApplicationContext;
/**
* @Version 1.0
* @Author huqiang
* @Description MainTest
* @Date 2023/11/29 11:44
**/
public class MainTest {
public static void main(String[] args) {
ClassPathXmlApplicationContext classPathXmlApplicationContext = new ClassPathXmlApplicationContext("applicationContext.xml");
//通过类型获取
EmployeesController employeesController = classPathXmlApplicationContext.getBean(EmployeesController.class);
employeesController.findEmployees();
System.out.println("======================================================");
//通过名称获取
EmployeesController bean = (EmployeesController)classPathXmlApplicationContext.getBean("oc");
bean.findEmployees();
bean.selectById(1);
}
}
可以看到相关的类已经构造好了
注入两个相同类型的 bean
往往我们在使用 Spring 的时候,注入接口,其下有多个实现类,往往会提示注入多个 bean,在这次代码中,也实现了
public <T> T getBean(Class<T> clazz) {
//首先根据class获取,获取不到再通过接口获取
if (iocClassContainer.containsKey(clazz)) {
return clazz.cast(iocClassContainer.get(clazz));
}
List<Object> computed = iocInterfacesContainer.compute(clazz, (key, value) -> {
if (value == null || value.isEmpty()) {
return null;
}
if (value.size() > 1) {
throw new IllegalArgumentException("只能获取到一个bean 但是获取到了 " + value.size() + "个相同类型的bean");
}
return value;
});
return computed == null ? null : clazz.cast(computed.get(0));
}
比如 我现在有个 EmployeesService
下面有两个实现类 EmployeesServiceImpl
、EmployeesService1
如果我在 controller 层,不根据 beanName 获取,则根据 class 获取和接口获取,就会出现找到多个 bean 异常
根据 bean 名称获取
在 Autowored 注解填充需要注入的 beanName
可以看到注入的是 impl
结论
从这个简单的代码中,基本实现了一个简单的 IoC 控制反转,实现了属性实例的传递,这个并不能解决循环依赖的问题,以及 Spring 三级缓存扩展,bean 的生命周期,相关前置后置处理扩展点等等,所以相当于学习了解的作用,后续再写写 Spring aop 另外一个核心